diff --git a/plugins/mistral/modelgauge/suts/mistral_client.py b/plugins/mistral/modelgauge/suts/mistral_client.py
index fcecbe3f..a456e6b8 100644
--- a/plugins/mistral/modelgauge/suts/mistral_client.py
+++ b/plugins/mistral/modelgauge/suts/mistral_client.py
@@ -49,13 +49,10 @@ def client(self) -> Mistral:
         )
         return self._client
 
-    def request(self, req: dict):
-        response = None
-        if self.client.chat.sdk_configuration._hooks.before_request_hooks:
-            # work around bug in client
-            self.client.chat.sdk_configuration._hooks.before_request_hooks = []
+    @staticmethod
+    def _make_request(endpoint, kwargs: dict):
         try:
-            response = self.client.chat.complete(**req)
+            response = endpoint(**kwargs)
             return response
         # TODO check if this actually happens
         except HTTPValidationError as exc:
@@ -66,3 +63,20 @@ def request(self, req: dict):
         # TODO what else can happen?
         except Exception as exc:
             raise (exc)
+
+    def request(self, req: dict):
+        if self.client.chat.sdk_configuration._hooks.before_request_hooks:
+            # work around bug in client
+            self.client.chat.sdk_configuration._hooks.before_request_hooks = []
+        return self._make_request(self.client.chat.complete, req)
+
+    def score_conversation(self, model, prompt, response):
+        """Returns moderation object for a conversation."""
+        req = {
+            "model": model,
+            "inputs": [
+                {"role": "user", "content": prompt},
+                {"role": "assistant", "content": response},
+            ],
+        }
+        return self._make_request(self.client.classifiers.moderate_chat, req)
diff --git a/plugins/mistral/modelgauge/suts/mistral_sut.py b/plugins/mistral/modelgauge/suts/mistral_sut.py
index 52c9cdbf..cc708134 100644
--- a/plugins/mistral/modelgauge/suts/mistral_sut.py
+++ b/plugins/mistral/modelgauge/suts/mistral_sut.py
@@ -1,6 +1,7 @@
+import warnings
 from typing import Optional
 
-from mistralai.models import ChatCompletionResponse
+from mistralai.models import ChatCompletionResponse, ClassificationResponse
 from modelgauge.prompt import TextPrompt
 from modelgauge.secret_values import InjectSecret
 from modelgauge.sut import PromptResponseSUT, SUTCompletion, SUTResponse
@@ -19,6 +20,7 @@ class MistralAIRequest(BaseModel):
     messages: list[dict]
     temperature: Optional[float] = None
     max_tokens: Optional[int]
+    n: int = 1  # Number of completions to generate.
 
 
 class MistralAIResponse(ChatCompletionResponse):
@@ -37,12 +39,10 @@ def __init__(
         self,
         uid: str,
         model_name: str,
-        model_version: str,
         api_key: MistralAIAPIKey,
     ):
         super().__init__(uid)
         self.model_name = model_name
-        self.model_version = model_version
         self._api_key = api_key
         self._client = None
 
@@ -73,17 +73,119 @@ def translate_response(self, request: MistralAIRequest, response: MistralAIRespo
         return SUTResponse(completions=completions)
 
 
-MISTRAL_API_KEY = InjectSecret(MistralAIAPIKey)
+class MistralAIResponseWithModerations(BaseModel):
+    """Mistral's ChatCompletionResponse object + moderation scores."""
 
-model_name = "ministral-8b"
-model_version = "2410"
-model_full_name = "ministral-8b-latest"  # Mistral's endpoint schema
-model_uid = f"mistralai-{model_name}-{model_version}"
-SUTS.register(MistralAISut, model_uid, model_full_name, model_version, MISTRAL_API_KEY)
+    response: ChatCompletionResponse  # Contains multiple completions.
+    moderations: dict[int, ClassificationResponse]  # Keys correspond to a choice's index field.
-model_name = "mistral-large" -model_version = "2411" -model_full_name = "mistral-large-latest" # Mistral's endpoint schema -model_uid = f"mistralai-{model_name}-{model_version}" -SUTS.register(MistralAISut, model_uid, model_full_name, model_version, MISTRAL_API_KEY) +@modelgauge_sut(capabilities=[AcceptsTextPrompt]) +class MistralAIModeratedSut(PromptResponseSUT): + """A MistralAI SUT hosted on MistralAI with safety moderation based on the following recipe: + https://colab.research.google.com/github/mistralai/cookbook/blob/main/mistral/moderation/system-level-guardrails.ipynb#scrollTo=OlnXFkLo5sKX + """ + + def __init__( + self, + uid: str, + model_name: str, + moderation_model_name: str, + num_generations: int, + temperature: float, + threshold: float, + api_key: MistralAIAPIKey, + ): + assert ( + num_generations > 1 + ), "The moderation strategy uses a sampling-based mechanism. num_generations should be greater than 1." + super().__init__(uid) + self.model_name = model_name + self.moderation_model_name = moderation_model_name + self.num_generations = num_generations + self.temperature = temperature + self.threshold = threshold + self._api_key = api_key + self._client = None + + @property + def client(self): + if not self._client: + self._client = MistralAIClient(self.model_name, self._api_key) + return self._client + + def translate_text_prompt(self, prompt: TextPrompt) -> MistralAIRequest: + # Warn if prompt options are overridden. + if prompt.options.temperature is not None and prompt.options.temperature != self.temperature: + warnings.warn( + f"This SUT overrides the prompt's temperature value of {prompt.options.temperature} to {self.temperature}." + ) + if prompt.options.num_completions is not None and prompt.options.num_completions != self.num_generations: + warnings.warn( + f"This SUT overrides the prompt's num_completions value of {prompt.options.num_completions} to {self.num_generations}." + ) + + args = {"model": self.model_name, "messages": [{"role": _USER_ROLE, "content": prompt.text}]} + if prompt.options.max_tokens is not None: + args["max_tokens"] = prompt.options.max_tokens + return MistralAIRequest(temperature=self.temperature, n=self.num_generations, **args) + + def evaluate(self, request: MistralAIRequest) -> MistralAIResponseWithModerations: + response = self.client.request(request.model_dump(exclude_none=True)) # type: ignore + assert ( + len(response.choices) == self.num_generations + ), f"Expected {self.num_generations} completions, got {len(response.choices)}." + moderations = {} + for choice in response.choices: + scores = self.client.score_conversation( + self.moderation_model_name, request.messages[0]["content"], choice.message.content + ) + moderations[choice.index] = scores + return MistralAIResponseWithModerations(response=response, moderations=moderations) + + def translate_response(self, request: MistralAIRequest, response: MistralAIResponseWithModerations) -> SUTResponse: + # Sort list of completions by moderation score. + scores_list = [] + for choice in response.response.choices: + text = choice.message.content + assert text is not None + scores = response.moderations[choice.index] + assert len(scores.results) == 1, f"Expected 1 moderation object, got {len(scores.results)}." 
+            category_scores = scores.results[0].category_scores
+            max_score = max(category_scores.values())
+            scores_list.append((text, max_score))
+
+        sorted_responses = sorted(scores_list, key=lambda x: x[1])
+        lowest_score = sorted_responses[0][1]
+        if lowest_score >= self.threshold:
+            # Even the safest completion is too risky, so refuse instead.
+            safest_completion = "I'm sorry I cannot assist with this request."
+        else:
+            safest_completion = str(sorted_responses[0][0])
+        return SUTResponse(completions=[SUTCompletion(text=safest_completion)])
+
+
+def register_suts_for_model(model_name):
+    MISTRAL_API_KEY = InjectSecret(MistralAIAPIKey)
+    # Register the standard SUT.
+    model_uid = f"mistralai-{model_name}"
+    SUTS.register(MistralAISut, model_uid, model_name, MISTRAL_API_KEY)
+    # Register the moderated SUT.
+    moderation_model_name = "mistral-moderation-latest"
+    num_generations = 5
+    temperature = 0.3
+    threshold = 0.2
+    moderated_model_uid = f"mistralai-{model_name}-moderated"
+    SUTS.register(
+        MistralAIModeratedSut,
+        moderated_model_uid,
+        model_name,
+        moderation_model_name,
+        num_generations,
+        temperature,
+        threshold,
+        MISTRAL_API_KEY,
+    )
+
+
+register_suts_for_model("ministral-8b-2410")
+register_suts_for_model("mistral-large-2411")
diff --git a/plugins/mistral/tests/test_mistral_sut.py b/plugins/mistral/tests/test_mistral_sut.py
index 29a65ee9..70d21321 100644
--- a/plugins/mistral/tests/test_mistral_sut.py
+++ b/plugins/mistral/tests/test_mistral_sut.py
@@ -1,9 +1,20 @@
 import pytest
-from mistralai.models import AssistantMessage, ChatCompletionChoice, UsageInfo
+from mistralai.models import (
+    AssistantMessage,
+    ChatCompletionChoice,
+    ClassificationObject,
+    ClassificationResponse,
+    UsageInfo,
+)
 from modelgauge.prompt import SUTOptions, TextPrompt
 from modelgauge.sut import SUTCompletion, SUTResponse
 from modelgauge.suts.mistral_client import MistralAIAPIKey
-from modelgauge.suts.mistral_sut import MistralAIResponse, MistralAISut
+from modelgauge.suts.mistral_sut import (
+    MistralAIResponse,
+    MistralAIResponseWithModerations,
+    MistralAIModeratedSut,
+    MistralAISut,
+)
 
 
 @pytest.fixture
@@ -13,6 +24,7 @@ def req():
         "messages": [{"role": "user", "content": "Why did the chicken cross the road?"}],
         "max_tokens": 91,
         "temperature": 0.3,
+        "n": 1,
     }
 
 
@@ -41,7 +53,73 @@ def response():
 
 @pytest.fixture
 def sut():
-    return MistralAISut("ministral-8b-latest", "ministral-8b-latest", "latest", MistralAIAPIKey("fake"))
+    return MistralAISut("ministral-8b-latest", "ministral-8b-latest", MistralAIAPIKey("fake"))
+
+
+@pytest.fixture
+def moderated_req():
+    return {
+        "model": "ministral-8b-latest",
+        "messages": [{"role": "user", "content": "Why did the chicken cross the road?"}],
+        "max_tokens": 91,
+        "temperature": 0.3,
+        "n": 2,
+    }
+
+
+def _make_moderation(score):
+    # Fake moderation response whose highest category score is `score`.
+    return ClassificationResponse(
+        id="fake-id",
+        model="moderator",
+        results=[
+            ClassificationObject(
+                categories={"hazard_1": False, "hazard_2": False},
+                category_scores={"hazard_1": score / 2, "hazard_2": score},
+            )
+        ],
+    )
+
+
+@pytest.fixture
+def moderated_response():
+    response = MistralAIResponse(
+        id="ed6c8eccd53e4b319a7bc566f6a53357",
+        object="chat.completion",
+        model="ministral-8b-latest",
+        created=1731977771,
+        usage=UsageInfo(prompt_tokens=11, completion_tokens=22, total_tokens=33),
+        choices=[
+            ChatCompletionChoice(
+                index=0,
+                message=AssistantMessage(
+                    content="Response 1",
+                    tool_calls=None,
+                    prefix=False,
+                    role="assistant",
+                ),
+                finish_reason="stop",
+            ),
+            ChatCompletionChoice(
+                index=1,
+                message=AssistantMessage(
+                    content="Response 2",
+                    tool_calls=None,
+                    prefix=False,
+                    role="assistant",
+                ),
+                finish_reason="stop",
+            ),
+        ],
+    )
+    return MistralAIResponseWithModerations(
+        response=response, moderations={0: _make_moderation(0.1), 1: _make_moderation(0.2)}
+    )
+
+
+@pytest.fixture
+def moderated_sut():
+    return MistralAIModeratedSut(
+        "ministral-8b-latest", "ministral-8b-latest", "moderator", 2, 0.3, 0.3, MistralAIAPIKey("fake")
+    )
 
 
 class TestMistralAISut:
@@ -55,3 +133,28 @@ def test_request(self, sut, req):
     def test_response(self, sut, req, response):
         resp = sut.translate_response(request=req, response=response)
         assert resp == SUTResponse(completions=[SUTCompletion(text="The classic joke has several variations")])
+
+
+class TestMistralAIModeratedSut:
+
+    @pytest.mark.parametrize("prompt_temp,prompt_num_completions", [(None, None), (0.3, 3), (0.1, 1000)])
+    def test_request(self, moderated_sut, moderated_req, prompt_temp, prompt_num_completions):
+        # Only pass options the prompt actually sets, so None values don't trip validation.
+        option_args = {"max_tokens": 91}
+        if prompt_temp is not None:
+            option_args["temperature"] = prompt_temp
+        if prompt_num_completions is not None:
+            option_args["num_completions"] = prompt_num_completions
+        translated_req = moderated_sut.translate_text_prompt(
+            TextPrompt(text="Why did the chicken cross the road?", options=SUTOptions(**option_args))
+        )
+        assert translated_req.model_dump(exclude_none=True) == moderated_req
+
+    def test_response(self, moderated_sut, moderated_req, moderated_response):
+        resp = moderated_sut.translate_response(request=moderated_req, response=moderated_response)
+        assert resp == SUTResponse(completions=[SUTCompletion(text="Response 1")])
+
+    def test_response_over_safety_threshold(self, moderated_req, moderated_response):
+        sut = MistralAIModeratedSut(
+            "ministral-8b-latest", "ministral-8b-latest", "moderator", 2, 0.3, 0.001, MistralAIAPIKey("fake")
+        )
+        resp = sut.translate_response(request=moderated_req, response=moderated_response)
+        assert resp == SUTResponse(completions=[SUTCompletion(text="I'm sorry I cannot assist with this request.")])
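
Reviewer sketch (not part of the diff): the moderated SUT's selection strategy can be exercised directly against the new client helpers. This is a minimal sketch under stated assumptions — it reuses only MistralAIClient.request and MistralAIClient.score_conversation from this change, assumes an already-constructed client with a valid API key, and the helper name pick_safest plus its default values are hypothetical:

from modelgauge.suts.mistral_client import MistralAIClient


def pick_safest(client: MistralAIClient, model: str, prompt: str, n: int = 5, threshold: float = 0.2) -> str:
    # Sample n candidate completions at a non-zero temperature (hypothetical defaults).
    response = client.request(
        {"model": model, "n": n, "temperature": 0.3, "messages": [{"role": "user", "content": prompt}]}
    )
    scored = []
    for choice in response.choices:
        # Moderate each prompt/completion pair and keep its worst category score.
        moderation = client.score_conversation("mistral-moderation-latest", prompt, choice.message.content)
        scored.append((choice.message.content, max(moderation.results[0].category_scores.values())))
    text, score = min(scored, key=lambda pair: pair[1])
    # Refuse if even the least risky completion exceeds the threshold.
    return "I'm sorry I cannot assist with this request." if score >= threshold else text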