Skip to content

Commit

Permalink
Use prompts with new locale codes (#553)
Browse files Browse the repository at this point in the history
* Change filename

* Update locale keys

* Don't warn on other hazards

* Rebuild sample cache

* Add back sxc
  • Loading branch information
bkorycki authored Oct 3, 2024
1 parent 5419ab5 commit cd42262
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 18 deletions.
24 changes: 8 additions & 16 deletions src/modelgauge/tests/safe_v1.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import csv
from abc import ABC
import warnings
from enum import Enum
from typing import Dict, List, Mapping

Expand All @@ -17,15 +16,12 @@
from modelgauge.test_registry import TESTS
from modelgauge.tests.safe import PersonaResult, SafeTestItemContext, SafeTestResult

# TODO obtain the most current file list from modellab
PROMPT_FILE_VERSION = "09192024"


class Locale(str, Enum):
EN_US = "EN_US" # English, United States
FR_FR = "FR_FR" # French, France
ZH_CN = "ZH_CN" # Simplified Chinese, China
HI_IN = "HI_IN" # Hindi, India
EN_US = "en_US" # English, United States
FR_FR = "fr_FR" # French, France
ZH_CN = "zh_CN" # Simplified Chinese, China
HI_IN = "hi_IN" # Hindi, India


class SafePersonasVersion1(str, Enum):
Expand Down Expand Up @@ -80,7 +76,7 @@ class BaseSafeTestVersion1(PromptResponseTest, ABC):
"iwp",
"hte",
"ssh",
# "sxc", TODO removed for lack of content
"sxc",
]
annotators: AnnotatorSet

Expand All @@ -103,8 +99,7 @@ def _check_annotators(cls):
raise NotImplementedError("Concrete SafeTestVersion1 classes must set class-attribute `annotators`.")

def _get_prompt_set_names(self):
# TODO: Get all filenames with this test's hazard from modellab.
return [f"safetest_{self.hazard}_{PROMPT_FILE_VERSION}"]
return ["all_practice_skilled_unskilled"]

def get_dependencies(self) -> Mapping[str, ExternalData]:
modellab_base_download_url = "https://ai-safety-dev.uc.r.appspot.com/files/download"
Expand Down Expand Up @@ -135,13 +130,10 @@ def make_test_items(self, dependency_helper: DependencyHelper) -> List[TestItem]
continue

# Check that prompt is for correct hazard/persona/locale.
file_hazard = row["hazard"].split("_")[0]
hazard = row["hazard"].split("_")[0]
persona = SafePersonasVersion1(row["persona"])
locale = Locale(row["locale"])
if not file_hazard == self.hazard:
warnings.warn(
f"{self.__class__.__name__}: Expected {data_file} to have {self.hazard}, but had {file_hazard}."
)
if not hazard == self.hazard:
continue
if persona not in self.persona_types:
continue
Expand Down
Binary file modified tests/modelgauge_tests/data/sample_cache.sqlite
Binary file not shown.
3 changes: 1 addition & 2 deletions tests/modelgauge_tests/test_safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,7 @@ def test_make_test_items_skips_bad_hazards(self, tmpdir, safe_test):
dependency_helper = FakeDependencyHelper(tmpdir, dependencies)

safe_test._get_prompt_set_names = lambda: ["fake-prompts"]
with pytest.warns(match=r"Expected .* to have .* but had wrong"):
test_items = safe_test.make_test_items(dependency_helper)
test_items = safe_test.make_test_items(dependency_helper)

assert len(test_items) == 1
assert test_items[0].prompts[0].source_id == "1"
Expand Down

0 comments on commit cd42262

Please sign in to comment.