-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* rename helm module to helm_runner to prevent clobbering the helm library in the namespace; add helm_interface which includes an InProcessHelmRunner * run black * sorting/optimizing imports * import from helm_runner
- Loading branch information
1 parent
d3deae2
commit 2c4e01d
Showing
11 changed files
with
185 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import subprocess | ||
from typing import Iterable | ||
|
||
import helm.benchmark.run_specs | ||
from helm.benchmark.config_registry import ( | ||
register_builtin_configs_from_helm_package, | ||
) | ||
from helm.benchmark.executor import ExecutionSpec | ||
from helm.benchmark.huggingface_registration import ( | ||
register_huggingface_hub_model_from_flag_value, | ||
) | ||
from helm.benchmark.model_deployment_registry import ( | ||
ClientSpec, | ||
ModelDeployment, | ||
register_model_deployment, | ||
) | ||
from helm.benchmark.presentation.run_entry import RunEntry | ||
from helm.benchmark.run import run_entries_to_run_specs | ||
from helm.benchmark.runner import Runner | ||
from helm.common.authentication import Authentication | ||
|
||
from coffee.helm_runner import HelmResult, HelmRunner, HelmSut, HelmTest | ||
|
||
helm.benchmark.run_specs.INCLUDE_GENERATIVE_HARMS_METRICS = True | ||
|
||
|
||
class InProcessHelmRunner(HelmRunner): | ||
def run(self, tests: list[HelmTest], suts: list[HelmSut], max_instances=10): | ||
self._execute( | ||
tests, | ||
suts, | ||
max_eval_instances=max_instances, | ||
suite="v1", | ||
num_threads=4, | ||
benchmark_output_path="run/benchmark_output", | ||
prod_env_path="run/prod_env", | ||
) | ||
|
||
output_dir = self._make_output_dir() | ||
|
||
# THIS IS A BIG, DUMB HACK until we unwind subprocess.CompletedProcess from the run mix. | ||
execution_result = subprocess.run("", shell=True, capture_output=True, cwd=output_dir) | ||
# END BIG DUMB HACK | ||
|
||
return HelmResult(tests, suts, output_dir, execution_result) | ||
|
||
def _execute( | ||
self, | ||
tests: Iterable["HelmTest"], | ||
suts: Iterable["HelmSut"], | ||
max_eval_instances: int = 10, | ||
suite: str = "v1", | ||
num_threads: int = 1, | ||
benchmark_output_path: str = "run/benchmark_output", | ||
prod_env_path: str = "run/prod_env", | ||
) -> None: | ||
register_builtin_configs_from_helm_package() | ||
for sut in suts: | ||
if sut.huggingface: | ||
register_huggingface_hub_model_from_flag_value(sut.key) | ||
model_deployment = ModelDeployment( | ||
name=sut.key, | ||
tokenizer_name=sut.tokenizer_name, | ||
max_sequence_length=sut.tokenizer_max_length, | ||
client_spec=ClientSpec(class_name="helm.proxy.clients.huggingface_client.HuggingFaceClient"), | ||
) | ||
register_model_deployment(model_deployment) | ||
run_entries = [RunEntry(r, 1, list()) for r in self._build_runspecs(suts, tests)] | ||
run_specs = run_entries_to_run_specs(run_entries, max_eval_instances=max_eval_instances) | ||
execution_spec = ExecutionSpec( | ||
url=None, | ||
auth=Authentication(""), | ||
local_path=prod_env_path, | ||
parallelism=num_threads, | ||
) | ||
runner = Runner( | ||
execution_spec=execution_spec, | ||
output_path=benchmark_output_path, | ||
suite=suite, | ||
skip_instances=False, | ||
cache_instances=False, | ||
cache_instances_only=False, | ||
skip_completed_runs=False, | ||
exit_on_error=False, | ||
) | ||
runner.run_all(run_specs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from unittest.mock import MagicMock | ||
|
||
import pytest | ||
|
||
from coffee import helm_interface | ||
from coffee.helm_runner import BbqHelmTest, HelmSut | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def monkeypatch_run_all(monkeypatch): | ||
mock_obj = MagicMock() | ||
monkeypatch.setattr(helm_interface.Runner, "run_all", mock_obj) | ||
return mock_obj | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def monkeypatch_run_one(monkeypatch): | ||
mock_obj = MagicMock() | ||
monkeypatch.setattr(helm_interface.Runner, "run_one", mock_obj) | ||
return mock_obj | ||
|
||
|
||
@pytest.fixture | ||
def monkeypatch_register_huggingface(monkeypatch): | ||
mock_obj = MagicMock() | ||
monkeypatch.setattr( | ||
helm_interface, | ||
"register_huggingface_hub_model_from_flag_value", | ||
mock_obj, | ||
) | ||
return mock_obj | ||
|
||
|
||
@pytest.fixture | ||
def monkeypatch_run_entries_to_run_specs(monkeypatch): | ||
mock_obj = MagicMock() | ||
monkeypatch.setattr(helm_interface, "run_entries_to_run_specs", mock_obj) | ||
return mock_obj | ||
|
||
|
||
def test_run_executions_registers_huggingface( | ||
monkeypatch, monkeypatch_register_huggingface, monkeypatch_run_entries_to_run_specs | ||
): | ||
# have to monkeypatch run_entries_to_runspecs since we can't register due to monkeypatching | ||
# register_huggingface_hub_model_from_flag_value | ||
runner = helm_interface.InProcessHelmRunner() | ||
|
||
runner.run([BbqHelmTest()], [HelmSut.FB_OPT_125M, HelmSut.GPT2]) | ||
monkeypatch_register_huggingface.assert_called_once_with("facebook/opt-125m") | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"tests, suts, expected", | ||
[ | ||
([BbqHelmTest()], [HelmSut.FB_OPT_125M, HelmSut.GPT2], 20), | ||
([BbqHelmTest()], [HelmSut.GPT2], 10), | ||
], | ||
) | ||
def test_generates_correct_number_runspecs(monkeypatch, monkeypatch_run_entries_to_run_specs, tests, suts, expected): | ||
runner = helm_interface.InProcessHelmRunner() | ||
|
||
runner.run(tests, suts) | ||
assert len(monkeypatch_run_entries_to_run_specs.call_args[0][0]) == expected | ||
|
||
|
||
def test_runs_run_all(monkeypatch, monkeypatch_run_all): | ||
runner = helm_interface.InProcessHelmRunner() | ||
|
||
runner.run([BbqHelmTest()], [HelmSut.GPT2]) | ||
monkeypatch_run_all.assert_called_once() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters