Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply ensemble updates nov 13 #691

Merged
merged 5 commits into from
Nov 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions src/modelgauge/auth/vllm_key.py

This file was deleted.

41 changes: 41 additions & 0 deletions src/modelgauge/auth/vllm_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from modelgauge.secret_values import RequiredSecret, SecretDescription


class Lg3LoraVllmApiKey(RequiredSecret):
@classmethod
def description(cls) -> SecretDescription:
return SecretDescription(
scope="lg3_lora_vllm",
key="api_key",
instructions="Contact MLCommons admin for access.",
)


class Lg3LoraVllmEndpointUrl(RequiredSecret):
@classmethod
def description(cls) -> SecretDescription:
return SecretDescription(
scope="lg3_lora_vllm",
key="endpoint_url",
instructions="Contact MLCommons admin for access.",
)


class Mistral7bVllmApiKey(RequiredSecret):
@classmethod
def description(cls) -> SecretDescription:
return SecretDescription(
scope="mistral_7b_vllm",
key="api_key",
instructions="Contact MLCommons admin for access.",
)


class Mistral7bVllmEndpointUrl(RequiredSecret):
@classmethod
def description(cls) -> SecretDescription:
return SecretDescription(
scope="mistral_7b_vllm",
key="endpoint_url",
instructions="Contact MLCommons admin for access.",
)
34 changes: 19 additions & 15 deletions src/modelgauge/private_ensemble_annotator_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
from typing import Any, Dict, List

from modelgauge.annotator_set import AnnotatorSet
from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.auth.together_key import TogetherApiKey
from modelgauge.auth.vllm_key import VllmApiKey
from modelgauge.auth.vllm_keys import (
Lg3LoraVllmApiKey,
Lg3LoraVllmEndpointUrl,
Mistral7bVllmApiKey,
Mistral7bVllmEndpointUrl,
)
from modelgauge.dependency_injection import _replace_with_injected
from modelgauge.secret_values import InjectSecret
from modelgauge.single_turn_prompt_response import TestItemAnnotations
Expand All @@ -19,16 +23,15 @@
from modelgauge.annotators.template_lg3_lora_annotator import (
config as LG3_LORA_CONFIG,
) # type: ignore
from modelgauge.annotators.mistral_7b_ruby_annotator import (
config as MISTRAL_7B_RUBY_CONFIG,
) # type: ignore
from modelgauge.annotators.mistral_8x22b_instruct_annotator import (
MISTRAL_8x22B_PE_TAMALPAIS_2024_09_09_CONFIG as MISTRAL_8x22B_CONFIG,
) # type: ignore
from modelgauge.annotators.prompt_engineered_annotator import (
PromptEngineeredAnnotator,
) # type: ignore
from modelgauge.annotators.wildguard_annotator import ( # type: ignore
WILDGUARD_ANNOTATOR_CONFIG,
WildguardAnnotator,
)
from modelgauge.ensemble.majority_vote_ensemble_strategy import MajorityVoteEnsembleStrategy # type: ignore
from modelgauge.safety_model_response import SafetyModelResponse # type: ignore
except:
Expand All @@ -41,8 +44,10 @@ class EnsembleAnnotatorSet(AnnotatorSet):

required_secret_types = {
"together_api_key": TogetherApiKey,
"huggingface_inference_token": HuggingFaceInferenceToken,
"vllm_api_key": VllmApiKey,
"lg3_vllm_api_key": Lg3LoraVllmApiKey,
"lg3_vllm_endpoint_url": Lg3LoraVllmEndpointUrl,
"mistral_7b_vllm_api_key": Mistral7bVllmApiKey,
"mistral_7b_vllm_endpoint_url": Mistral7bVllmEndpointUrl,
}

def __init__(self, secrets):
Expand All @@ -55,21 +60,20 @@ def __init__(self, secrets):

def _register_annotators(self):
self.annotators = [
MISTRAL_7B_RUBY_CONFIG.name,
MISTRAL_8x22B_CONFIG.name,
LLAMA_3_70B_CONFIG.name,
LG3_LORA_CONFIG.name,
WILDGUARD_ANNOTATOR_CONFIG.name,
]

def __configure_vllm_annotators(self):
self.secrets["vllm_endpoint_url"] = os.environ.get("VLLM_ENDPOINT_URL", "")
assert self.secrets["vllm_endpoint_url"], "Environment variable `VLLM_ENDPOINT_URL` is not set."
LG3_LORA_CONFIG.api_key = self.secrets["vllm_api_key"].value
LG3_LORA_CONFIG.base_url = self.secrets["vllm_endpoint_url"]
LG3_LORA_CONFIG.api_key = self.secrets["lg3_vllm_api_key"].value
LG3_LORA_CONFIG.base_url = self.secrets["lg3_vllm_endpoint_url"].value
MISTRAL_7B_RUBY_CONFIG.api_key = self.secrets["mistral_7b_vllm_api_key"].value
MISTRAL_7B_RUBY_CONFIG.base_url = self.secrets["mistral_7b_vllm_endpoint_url"].value

def __configure_huggingface_annotators(self):
WILDGUARD_ANNOTATOR_CONFIG.token = self.secrets["huggingface_inference_token"].value
assert WILDGUARD_ANNOTATOR_CONFIG.is_valid(), "HuggingFace configuration is missing a token or endpoint URL."
return

def __configure_together_annotators(self):
MISTRAL_8x22B_CONFIG.llm_config.api_key = self.secrets["together_api_key"]
Expand Down
12 changes: 0 additions & 12 deletions tests/modelgauge_tests/test_safe.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,6 @@
import pytest

from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.auth.together_key import TogetherApiKey
from modelgauge.auth.vllm_key import VllmApiKey

try:
from modelgauge.private_ensemble_annotator_set import EnsembleAnnotatorSet

FAKE_HF_TOKEN = HuggingFaceInferenceToken("fake-hf-token")
FAKE_VLLM_KEY = VllmApiKey("fake-vllm-key")
except:
FAKE_HF_TOKEN = None # type: ignore
FAKE_VLLM_KEY = None # type: ignore
pass
from modelgauge.prompt import TextPrompt
from modelgauge.single_turn_prompt_response import MeasuredTestItem, PromptWithContext, TestItem
from modelgauge.tests.safe import (
Expand Down