-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
initial helm interface and in-process runner; remove assumptions that a runner will use subprocess; stop clobbering helm package namespace; tests
- Loading branch information
1 parent
327e01a
commit 0619e4b
Showing
9 changed files
with
178 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from itertools import product | ||
from typing import Iterable, TYPE_CHECKING | ||
|
||
from helm.benchmark.config_registry import ( | ||
register_builtin_configs_from_helm_package, | ||
) | ||
from helm.benchmark.executor import ExecutionSpec | ||
from helm.benchmark.huggingface_registration import ( | ||
register_huggingface_hub_model_from_flag_value, | ||
) | ||
from helm.benchmark.presentation.run_entry import RunEntry | ||
from helm.benchmark.run import run_entries_to_run_specs | ||
from helm.benchmark.runner import Runner | ||
from helm.common.authentication import Authentication | ||
|
||
if TYPE_CHECKING: | ||
from helm_runner import HelmSut, HelmTest | ||
|
||
from helm.benchmark.runner import RunnerError | ||
|
||
|
||
def run_executions(
    tests: Iterable["HelmTest"],
    suts: Iterable["HelmSut"],
    max_eval_instances: int = 10,
    suite: str = "v1",
    num_threads: int = 4,
    benchmark_output_path: str = "run/benchmark_output",
    prod_env_path: str = "run/prod_env",
) -> None:
    """Run every test against every SUT in-process through HELM's ``Runner``.

    One ``RunEntry`` is created per (test, SUT, runspec) combination, with the
    description ``"<runspec>,model=<sut.key>"``. The entries are converted to
    run specs and handed to ``Runner.run_all``.

    Args:
        tests: Objects exposing ``runspecs()``; each returned runspec is
            combined with every SUT.
        suts: Systems under test. Each must expose ``key`` and
            ``huggingface``; SUTs with a truthy ``huggingface`` are registered
            via ``register_huggingface_hub_model_from_flag_value`` first.
            Any iterable is accepted, including one-shot generators.
        max_eval_instances: Forwarded to ``run_entries_to_run_specs``.
        suite: Forwarded to ``Runner`` as ``suite``.
        num_threads: Forwarded to ``ExecutionSpec`` as ``parallelism``.
        benchmark_output_path: Forwarded to ``Runner`` as ``output_path``.
        prod_env_path: Forwarded to ``ExecutionSpec`` as ``local_path``.
    """
    register_builtin_configs_from_helm_package()

    # ``suts`` is walked twice below (registration, then the cross product).
    # Materialize it once so a generator argument is not already exhausted by
    # the time the run entries are built.
    sut_list = list(suts)

    for sut in sut_list:
        if sut.huggingface:
            register_huggingface_hub_model_from_flag_value(sut.key)

    run_entries = [
        RunEntry(description=f"{runspec},model={sut.key}", priority=1, groups=[])
        for test, sut in product(tests, sut_list)
        for runspec in test.runspecs()
    ]
    run_specs = run_entries_to_run_specs(
        run_entries, max_eval_instances=max_eval_instances
    )

    # No remote URL and an empty credential: execution is configured against
    # the local path only.
    execution_spec = ExecutionSpec(
        url=None,
        auth=Authentication(""),
        local_path=prod_env_path,
        parallelism=num_threads,
    )
    runner = Runner(
        execution_spec=execution_spec,
        output_path=benchmark_output_path,
        suite=suite,
        skip_instances=False,
        cache_instances=False,
        cache_instances_only=False,
        skip_completed_runs=False,
        exit_on_error=False,
    )
    runner.run_all(run_specs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from unittest.mock import MagicMock | ||
|
||
import pytest | ||
|
||
from coffee import helm_interface | ||
from coffee.helm_runner import BbqHelmTest, HelmSut | ||
|
||
|
||
@pytest.fixture(autouse=True)
def monkeypatch_runner(monkeypatch):
    """Swap ``helm_interface.Runner`` for a ``MagicMock`` in every test.

    The fixture is autouse, so it applies without being requested. The mock
    is returned for tests that want to inspect it.
    """
    runner_stub = MagicMock()
    monkeypatch.setattr(helm_interface, "Runner", runner_stub)
    return runner_stub
|
||
|
||
@pytest.fixture
def monkeypatch_register_huggingface(monkeypatch):
    """Replace the Hugging Face registration hook on ``helm_interface``.

    Returns the ``MagicMock`` that stands in for
    ``register_huggingface_hub_model_from_flag_value`` so tests can assert
    on how it was called.
    """
    target_name = "register_huggingface_hub_model_from_flag_value"
    register_stub = MagicMock()
    monkeypatch.setattr(helm_interface, target_name, register_stub)
    return register_stub
|
||
|
||
@pytest.fixture
def monkeypatch_run_entries_to_run_specs(monkeypatch):
    """Replace ``helm_interface.run_entries_to_run_specs`` with a mock.

    Returns the ``MagicMock`` so tests can inspect the run entries that
    ``run_executions`` passed to it.
    """
    to_run_specs_stub = MagicMock()
    monkeypatch.setattr(
        helm_interface, "run_entries_to_run_specs", to_run_specs_stub
    )
    return to_run_specs_stub
|
||
|
||
def test_run_executions_registers_huggingface(
    monkeypatch, monkeypatch_register_huggingface, monkeypatch_run_entries_to_run_specs
):
    """Registration is called exactly once, with ``"facebook/opt-125m"``."""
    # run_entries_to_run_specs has to be mocked as well: with
    # register_huggingface_hub_model_from_flag_value patched out, no model is
    # actually registered.
    tests = [BbqHelmTest()]
    suts = [HelmSut.FB_OPT_125M, HelmSut.GPT2]

    helm_interface.run_executions(tests, suts)

    monkeypatch_register_huggingface.assert_called_once_with("facebook/opt-125m")
|
||
|
||
@pytest.mark.parametrize(
    "tests, suts, expected",
    [
        ([BbqHelmTest()], [HelmSut.FB_OPT_125M, HelmSut.GPT2], 20),
        ([BbqHelmTest()], [HelmSut.GPT2], 10),
    ],
)
def test_generates_correct_number_runspecs(
    monkeypatch, monkeypatch_run_entries_to_run_specs, tests, suts, expected
):
    """The run-entry list passed to ``run_entries_to_run_specs`` has the expected length."""
    helm_interface.run_executions(tests, suts)

    # call_args is an (args, kwargs) pair; the run entries are the first
    # positional argument.
    positional_args, _ = monkeypatch_run_entries_to_run_specs.call_args
    run_entries = positional_args[0]
    assert len(run_entries) == expected
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters