diff --git a/src/modelgauge/private_ensemble_annotator_set.py b/src/modelgauge/private_ensemble_annotator_set.py index c6f0f7e1..11d9e83e 100644 --- a/src/modelgauge/private_ensemble_annotator_set.py +++ b/src/modelgauge/private_ensemble_annotator_set.py @@ -1,12 +1,10 @@ import os from typing import Any, Dict, List -from modelgauge.annotator_registry import ANNOTATORS from modelgauge.annotator_set import AnnotatorSet from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken from modelgauge.auth.together_key import TogetherApiKey from modelgauge.auth.vllm_key import VllmApiKey -from modelgauge.config import load_secrets_from_config from modelgauge.dependency_injection import _replace_with_injected from modelgauge.secret_values import InjectSecret from modelgauge.single_turn_prompt_response import TestItemAnnotations @@ -49,7 +47,6 @@ class EnsembleAnnotatorSet(AnnotatorSet): def __init__(self, secrets): self.secrets = secrets - # TODO: Pass in the strategy as a parameter for easy swapping. self.strategy = MajorityVoteEnsembleStrategy() self.__configure_vllm_annotators() self.__configure_huggingface_annotators() @@ -57,14 +54,6 @@ def __init__(self, secrets): self._register_annotators() def _register_annotators(self): - # TODO: Register annotators in secret repo. - # ANNOTATORS.register( - # PromptEngineeredAnnotator(uid=MISTRAL_8x22B_CONFIG.name, config=MISTRAL_8x22B_CONFIG), - # ) - # ANNOTATORS.register(PromptEngineeredAnnotator(uid=LLAMA_3_70B_CONFIG.name, config=LLAMA_3_70B_CONFIG)) - # ANNOTATORS.register(LlamaGuard2LoRAAnnotator(uid=LG2_LORA_CONFIG.name, config=LG2_LORA_CONFIG)) - # ANNOTATORS.register(WildguardAnnotator(uid=WILDGUARD_ANNOTATOR_CONFIG.name, config=WILDGUARD_ANNOTATOR_CONFIG)) - self.annotators = [ MISTRAL_8x22B_CONFIG.name, LLAMA_3_70B_CONFIG.name, diff --git a/tests/modelgauge_tests/test_private_ensemble_annotator_set.py b/tests/modelgauge_tests/test_private_ensemble_annotator_set.py index b101b43c..5992d8d2 100644 --- a/tests/modelgauge_tests/test_private_ensemble_annotator_set.py +++ b/tests/modelgauge_tests/test_private_ensemble_annotator_set.py @@ -1,42 +1,27 @@ import os -from unittest.mock import Mock, patch - -import pytest - -from modelgauge.suts.together_client import TogetherApiKey +from unittest import mock +@mock.patch.dict(os.environ, {"VLLM_ENDPOINT_URL": "https://example.org/"}, clear=True) def test_can_load(): + """This just makes sure things are properly connected. Fuller testing is in the private code.""" try: - # EnsembleAnnotator is required by the private annotators - # If we can import it, then the EnsembleAnnotatorSet can be instantiated + from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken + from modelgauge.auth.vllm_key import VllmApiKey from modelgauge.private_ensemble_annotator_set import EnsembleAnnotatorSet + from modelgauge.suts.together_client import TogetherApiKey + from modelgauge.annotators.wildguard_annotator import WildguardAnnotator - assert True + in_private = True except: - # The EnsembleAnnotator can't be implemented, so the EnsembleAnnotatorSet can't either - with pytest.raises(NotImplementedError): - from modelgauge.private_ensemble_annotator_set import EnsembleAnnotatorSet + in_private = False - -def test_annotators(): - try: - from modelgauge.private_ensemble_annotator_set import ( - EnsembleAnnotatorSet, - HuggingFaceKey, - VllmApiKey, - ) - - os.environ["VLLM_ENDPOINT_URL"] = "fake" - annotators = EnsembleAnnotatorSet( + if in_private: + annotator_set = EnsembleAnnotatorSet( secrets={ "together_api_key": TogetherApiKey("fake"), - "huggingface_key": HuggingFaceKey("fake"), + "huggingface_inference_token": HuggingFaceInferenceToken("fake"), "vllm_api_key": VllmApiKey("fake"), } ) - assert len(annotators.annotators) == 4 - except: - # The EnsembleAnnotator can't be implemented, so the EnsembleAnnotatorSet can't either - with pytest.raises(NotImplementedError): - from modelgauge.private_ensemble_annotator_set import EnsembleAnnotatorSet + assert len(annotator_set.annotators) == 4