From ea311f933f3f1975f11aed3d12cdc33c69c52e52 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 9 Dec 2024 13:14:29 -1000 Subject: [PATCH 1/9] Create moderated Mistral SUT --- .../mistral/modelgauge/suts/mistral_client.py | 12 ++ .../mistral/modelgauge/suts/mistral_sut.py | 118 +++++++++++++++++- 2 files changed, 129 insertions(+), 1 deletion(-) diff --git a/plugins/mistral/modelgauge/suts/mistral_client.py b/plugins/mistral/modelgauge/suts/mistral_client.py index fcecbe3f..2dfbead4 100644 --- a/plugins/mistral/modelgauge/suts/mistral_client.py +++ b/plugins/mistral/modelgauge/suts/mistral_client.py @@ -66,3 +66,15 @@ def request(self, req: dict): # TODO what else can happen? except Exception as exc: raise (exc) + + def score_conversation(self, model, prompt, response): + """Returns moderation object for a conversation.""" + # TODO: Wrap in try-except block. + response = self.client.classifiers.moderate_chat( + model=model, + inputs=[ + {"role": "user", "content": prompt}, + {"role": "assistant", "content": response}, + ], + ) + return response diff --git a/plugins/mistral/modelgauge/suts/mistral_sut.py b/plugins/mistral/modelgauge/suts/mistral_sut.py index 52c9cdbf..5aa4252b 100644 --- a/plugins/mistral/modelgauge/suts/mistral_sut.py +++ b/plugins/mistral/modelgauge/suts/mistral_sut.py @@ -1,6 +1,7 @@ +import warnings from typing import Optional -from mistralai.models import ChatCompletionResponse +from mistralai.models import ChatCompletionResponse, ClassificationResponse from modelgauge.prompt import TextPrompt from modelgauge.secret_values import InjectSecret from modelgauge.sut import PromptResponseSUT, SUTCompletion, SUTResponse @@ -19,6 +20,7 @@ class MistralAIRequest(BaseModel): messages: list[dict] temperature: Optional[float] = None max_tokens: Optional[int] + n: int = 1 # Number of completions to generate. class MistralAIResponse(ChatCompletionResponse): @@ -73,8 +75,103 @@ def translate_response(self, request: MistralAIRequest, response: MistralAIRespo return SUTResponse(completions=completions) +class MistralAIResponseWithModerations(BaseModel): + """Mistral's ChatCompletionResponse object + moderation scores.""" + + response: ChatCompletionResponse # Contains multiple completions. + moderations: dict[int, ClassificationResponse] # Keys correspond to a choice's index field. + + +@modelgauge_sut(capabilities=[AcceptsTextPrompt]) +class MistralAIModeratedSut(PromptResponseSUT): + """A MistralAI SUT hosted on MistralAI with safety moderation based on the following recipe: + https://colab.research.google.com/github/mistralai/cookbook/blob/main/mistral/moderation/system-level-guardrails.ipynb#scrollTo=OlnXFkLo5sKX + """ + + def __init__( + self, + uid: str, + model_name: str, + model_version: str, + moderation_model_name: str, + num_generations: int, + temperature: float, + threshold: float, + api_key: MistralAIAPIKey, + ): + assert num_generations > 1, ( + "The moderation strategy uses a sampling-based mechanism. num_generations should " "be greater than 1." 
+ ) + super().__init__(uid) + self.model_name = model_name + self.model_version = model_version + self.moderation_model_name = moderation_model_name + self.num_generations = num_generations + self.temperature = temperature + self.threshold = threshold + self._api_key = api_key + self._client = None + + @property + def client(self): + if not self._client: + self._client = MistralAIClient(self.model_name, self._api_key) + return self._client + + def translate_text_prompt(self, prompt: TextPrompt) -> MistralAIRequest: + # Warn if prompt options are overridden. + if prompt.options.temperature is not None and prompt.options.temperature != self.temperature: + warnings.warn( + f"This SUT overrides the prompt's temperature value of {prompt.options.temperature} to {self.temperature}." + ) + if prompt.options.num_completions is not None and prompt.options.num_completions != self.num_generations: + warnings.warn( + f"This SUT overrides the prompt's num_completions value of {prompt.options.num_completions} to {self.num_generations}." + ) + + args = {"model": self.model_name, "messages": [{"role": _USER_ROLE, "content": prompt.text}]} + if prompt.options.max_tokens is not None: + args["max_tokens"] = prompt.options.max_tokens + return MistralAIRequest(temperature=self.temperature, n=self.num_generations, **args) + + def evaluate(self, request: MistralAIRequest) -> MistralAIResponseWithModerations: + """Override the normal SUT evaluate behavior to include moderation mechanism.""" + response = self.client.request(request.model_dump(exclude_none=True)) # type: ignore + assert ( + len(response.choices) == self.num_generations + ), f"Expected {self.num_generations} completions, got {len(response.choices)}." + moderations = {} + for choice in response.choices: + scores = self.client.score_conversation( + self.moderation_model_name, request.messages[0]["content"], choice.message.content + ) + moderations[choice.index] = scores + return MistralAIResponseWithModerations(response=response, moderations=moderations) + + def translate_response(self, request: MistralAIRequest, response: MistralAIResponseWithModerations) -> SUTResponse: + # Sort list of completions by moderation score. + scores_list = [] + for choice in response.response.choices: + text = choice.message.content + assert text is not None + scores = response.moderations[choice.index] + assert len(scores.results) == 1, f"Expected 1 moderation object, got {len(scores.results)}." + category_scores = scores.results[0].category_scores + max_score = max(category_scores.values()) + scores_list.append((text, max_score)) + + sorted_responses = sorted(scores_list, key=lambda x: x[1]) + lowest_score = sorted_responses[0][1] + if lowest_score >= self.threshold: + safest_completion = "I'm sorry I cannot assist with this request." + else: + safest_completion = str(sorted_responses[0][0]) + return SUTResponse(completions=[SUTCompletion(text=str(safest_completion))]) + + MISTRAL_API_KEY = InjectSecret(MistralAIAPIKey) +# TODO: Register both moderated and unmoderated SUTs. 
model_name = "ministral-8b" model_version = "2410" model_full_name = "ministral-8b-latest" # Mistral's endpoint schema @@ -87,3 +184,22 @@ def translate_response(self, request: MistralAIRequest, response: MistralAIRespo model_full_name = "mistral-large-latest" # Mistral's endpoint schema model_uid = f"mistralai-{model_name}-{model_version}" SUTS.register(MistralAISut, model_uid, model_full_name, model_version, MISTRAL_API_KEY) + +moderation_model_name = "mistral-moderation-latest" +num_generations = 2 +temperature = 0.3 +threshold = 0.2 +# TODO: Add more values to UID +model_uid = f"mistralai-{model_name}-{model_version}-moderated" + +SUTS.register( + MistralAIModeratedSut, + model_uid, + model_full_name, + model_version, + moderation_model_name, + num_generations, + temperature, + threshold, + MISTRAL_API_KEY, +) From cbfe5bf10524abb21084421ec586d656b51368cd Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 9 Dec 2024 15:14:59 -1000 Subject: [PATCH 2/9] added tests --- .../mistral/modelgauge/suts/mistral_sut.py | 7 +- plugins/mistral/tests/test_mistral_sut.py | 107 +++++++++++++++++- 2 files changed, 108 insertions(+), 6 deletions(-) diff --git a/plugins/mistral/modelgauge/suts/mistral_sut.py b/plugins/mistral/modelgauge/suts/mistral_sut.py index 5aa4252b..e57ff446 100644 --- a/plugins/mistral/modelgauge/suts/mistral_sut.py +++ b/plugins/mistral/modelgauge/suts/mistral_sut.py @@ -99,9 +99,9 @@ def __init__( threshold: float, api_key: MistralAIAPIKey, ): - assert num_generations > 1, ( - "The moderation strategy uses a sampling-based mechanism. num_generations should " "be greater than 1." - ) + assert ( + num_generations > 1 + ), "The moderation strategy uses a sampling-based mechanism. num_generations should be greater than 1." super().__init__(uid) self.model_name = model_name self.model_version = model_version @@ -135,7 +135,6 @@ def translate_text_prompt(self, prompt: TextPrompt) -> MistralAIRequest: return MistralAIRequest(temperature=self.temperature, n=self.num_generations, **args) def evaluate(self, request: MistralAIRequest) -> MistralAIResponseWithModerations: - """Override the normal SUT evaluate behavior to include moderation mechanism.""" response = self.client.request(request.model_dump(exclude_none=True)) # type: ignore assert ( len(response.choices) == self.num_generations diff --git a/plugins/mistral/tests/test_mistral_sut.py b/plugins/mistral/tests/test_mistral_sut.py index 29a65ee9..47fa9714 100644 --- a/plugins/mistral/tests/test_mistral_sut.py +++ b/plugins/mistral/tests/test_mistral_sut.py @@ -1,9 +1,20 @@ import pytest -from mistralai.models import AssistantMessage, ChatCompletionChoice, UsageInfo +from mistralai.models import ( + AssistantMessage, + ChatCompletionChoice, + ClassificationObject, + ClassificationResponse, + UsageInfo, +) from modelgauge.prompt import SUTOptions, TextPrompt from modelgauge.sut import SUTCompletion, SUTResponse from modelgauge.suts.mistral_client import MistralAIAPIKey -from modelgauge.suts.mistral_sut import MistralAIResponse, MistralAISut +from modelgauge.suts.mistral_sut import ( + MistralAIResponse, + MistralAIResponseWithModerations, + MistralAIModeratedSut, + MistralAISut, +) @pytest.fixture @@ -13,6 +24,7 @@ def req(): "messages": [{"role": "user", "content": "Why did the chicken cross the road?"}], "max_tokens": 91, "temperature": 0.3, + "n": 1, } @@ -44,6 +56,72 @@ def sut(): return MistralAISut("ministral-8b-latest", "ministral-8b-latest", "latest", MistralAIAPIKey("fake")) +@pytest.fixture +def moderated_req(): + 
+    return {
+        "model": "ministral-8b-latest",
+        "messages": [{"role": "user", "content": "Why did the chicken cross the road?"}],
+        "max_tokens": 91,
+        "temperature": 0.3,
+        "n": 2,
+    }
+
+
+def _make_moderation(score):
+    return ClassificationResponse(
+        id="fake-id",
+        model="moderator",
+        results=[
+            ClassificationObject(
+                categories={"hazard_1": False, "hazard_2": False}, category_scores={"hazard_1": score / 2, "hazard_2": score}
+            )
+        ],
+    )
+
+
+@pytest.fixture
+def moderated_response():
+    response = MistralAIResponse(
+        id="ed6c8eccd53e4b319a7bc566f6a53357",
+        object="chat.completion",
+        model="ministral-8b-latest",
+        created=1731977771,
+        usage=UsageInfo(prompt_tokens=11, completion_tokens=22, total_tokens=33),
+        choices=[
+            ChatCompletionChoice(
+                index=0,
+                message=AssistantMessage(
+                    content="Response 1",
+                    tool_calls=None,
+                    prefix=False,
+                    role="assistant",
+                ),
+                finish_reason="stop",
+            ),
+            ChatCompletionChoice(
+                index=1,
+                message=AssistantMessage(
+                    content="Response 2",
+                    tool_calls=None,
+                    prefix=False,
+                    role="assistant",
+                ),
+                finish_reason="stop",
+            ),
+        ],
+    )
+    return MistralAIResponseWithModerations(
+        response=response, moderations={0: _make_moderation(0.1), 1: _make_moderation(0.2)}
+    )
+
+
+@pytest.fixture
+def moderated_sut():
+    return MistralAIModeratedSut(
+        "ministral-8b-latest", "ministral-8b-latest", "latest", "moderator", 2, 0.3, 0.3, MistralAIAPIKey("fake")
+    )
+
+
 class TestMistralAISut:
 
     def test_request(self, sut, req):
@@ -55,3 +133,28 @@ def test_request(self, sut, req):
     def test_response(self, sut, req, response):
         resp = sut.translate_response(request=req, response=response)
         assert resp == SUTResponse(completions=[SUTCompletion(text="The classic joke has several variations")])
+
+
+class TestMistralAIModeratedSut:
+
+    @pytest.mark.parametrize("prompt_temp,prompt_num_completions", [(None, None), (0.3, 3), (0.1, 1000)])
+    def test_request(self, moderated_sut, moderated_req, prompt_temp, prompt_num_completions):
+        translated_req = moderated_sut.translate_text_prompt(
+            TextPrompt(
+                text="Why did the chicken cross the road?",
+                options=SUTOptions(temperature=prompt_temp, max_tokens=91),
+                num_completions=prompt_num_completions,
+            )
+        )
+        assert translated_req.model_dump(exclude_none=True) == moderated_req
+
+    def test_response(self, moderated_sut, moderated_req, moderated_response):
+        resp = moderated_sut.translate_response(request=moderated_req, response=moderated_response)
+        assert resp == SUTResponse(completions=[SUTCompletion(text="Response 1")])
+
+    def test_response_over_safety_threshold(self, moderated_req, moderated_response):
+        sut = MistralAIModeratedSut(
+            "ministral-8b-latest", "ministral-8b-latest", "latest", "moderator", 2, 0.3, 0.001, MistralAIAPIKey("fake")
+        )
+        resp = sut.translate_response(request=moderated_req, response=moderated_response)
+        assert resp == SUTResponse(completions=[SUTCompletion(text="I'm sorry I cannot assist with this request.")])

From d187ab4d1e6e4da87f9319d6d1e46cdebf05bcb8 Mon Sep 17 00:00:00 2001
From: Barbara Korycki
Date: Mon, 9 Dec 2024 15:22:24 -1000
Subject: [PATCH 3/9] Wrap moderation request in retry block

---
 .../mistral/modelgauge/suts/mistral_client.py | 26 ++++++++++---------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/plugins/mistral/modelgauge/suts/mistral_client.py b/plugins/mistral/modelgauge/suts/mistral_client.py
index 2dfbead4..835221db 100644
--- a/plugins/mistral/modelgauge/suts/mistral_client.py
+++ b/plugins/mistral/modelgauge/suts/mistral_client.py
@@ -49,13 +49,10 @@ def 
client(self) -> Mistral: ) return self._client - def request(self, req: dict): - response = None - if self.client.chat.sdk_configuration._hooks.before_request_hooks: - # work around bug in client - self.client.chat.sdk_configuration._hooks.before_request_hooks = [] + @staticmethod + def _retry_request(endpoint, kwargs: dict): try: - response = self.client.chat.complete(**req) + response = endpoint(**kwargs) return response # TODO check if this actually happens except HTTPValidationError as exc: @@ -67,14 +64,19 @@ def request(self, req: dict): except Exception as exc: raise (exc) + def request(self, req: dict): + if self.client.chat.sdk_configuration._hooks.before_request_hooks: + # work around bug in client + self.client.chat.sdk_configuration._hooks.before_request_hooks = [] + return self._retry_request(self.client.chat.complete, req) + def score_conversation(self, model, prompt, response): """Returns moderation object for a conversation.""" - # TODO: Wrap in try-except block. - response = self.client.classifiers.moderate_chat( - model=model, - inputs=[ + req = { + "model": model, + "inputs": [ {"role": "user", "content": prompt}, {"role": "assistant", "content": response}, ], - ) - return response + } + return self._retry_request(self.client.classifiers.moderate_chat, req) From 7b69b8ae2f1733cf7ed8eddad6e35386727cc1b8 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Tue, 10 Dec 2024 07:17:32 -1000 Subject: [PATCH 4/9] model version is a part of its name --- .../mistral/modelgauge/suts/mistral_sut.py | 31 +++++++------------ 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/plugins/mistral/modelgauge/suts/mistral_sut.py b/plugins/mistral/modelgauge/suts/mistral_sut.py index e57ff446..c6f76126 100644 --- a/plugins/mistral/modelgauge/suts/mistral_sut.py +++ b/plugins/mistral/modelgauge/suts/mistral_sut.py @@ -39,12 +39,10 @@ def __init__( self, uid: str, model_name: str, - model_version: str, api_key: MistralAIAPIKey, ): super().__init__(uid) self.model_name = model_name - self.model_version = model_version self._api_key = api_key self._client = None @@ -92,19 +90,17 @@ def __init__( self, uid: str, model_name: str, - model_version: str, moderation_model_name: str, num_generations: int, temperature: float, threshold: float, api_key: MistralAIAPIKey, ): - assert ( - num_generations > 1 - ), "The moderation strategy uses a sampling-based mechanism. num_generations should be greater than 1." + assert num_generations > 1, ( + "The moderation strategy uses a sampling-based mechanism. num_generations should be greater than 1." + ) super().__init__(uid) self.model_name = model_name - self.model_version = model_version self.moderation_model_name = moderation_model_name self.num_generations = num_generations self.temperature = temperature @@ -171,31 +167,26 @@ def translate_response(self, request: MistralAIRequest, response: MistralAIRespo MISTRAL_API_KEY = InjectSecret(MistralAIAPIKey) # TODO: Register both moderated and unmoderated SUTs. 
-model_name = "ministral-8b" -model_version = "2410" -model_full_name = "ministral-8b-latest" # Mistral's endpoint schema -model_uid = f"mistralai-{model_name}-{model_version}" -SUTS.register(MistralAISut, model_uid, model_full_name, model_version, MISTRAL_API_KEY) +model_name = "ministral-8b-2410" # Mistral's endpoint schema +model_uid = f"mistralai-{model_name}" +SUTS.register(MistralAISut, model_uid, model_name, MISTRAL_API_KEY) -model_name = "mistral-large" -model_version = "2411" -model_full_name = "mistral-large-latest" # Mistral's endpoint schema -model_uid = f"mistralai-{model_name}-{model_version}" -SUTS.register(MistralAISut, model_uid, model_full_name, model_version, MISTRAL_API_KEY) +model_name = "mistral-large-2411" +model_uid = f"mistralai-{model_name}" +SUTS.register(MistralAISut, model_uid, model_name, MISTRAL_API_KEY) moderation_model_name = "mistral-moderation-latest" num_generations = 2 temperature = 0.3 threshold = 0.2 # TODO: Add more values to UID -model_uid = f"mistralai-{model_name}-{model_version}-moderated" +model_uid = f"mistralai-{model_name}-moderated" SUTS.register( MistralAIModeratedSut, model_uid, - model_full_name, - model_version, + model_name, moderation_model_name, num_generations, temperature, From 65409e4b7fd7e3a9c4d3ae7276def88a0d2368d9 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Tue, 10 Dec 2024 08:10:55 -1000 Subject: [PATCH 5/9] Clean up SUT registration --- .../mistral/modelgauge/suts/mistral_sut.py | 54 +++++++++---------- 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/plugins/mistral/modelgauge/suts/mistral_sut.py b/plugins/mistral/modelgauge/suts/mistral_sut.py index c6f76126..c4b18e9b 100644 --- a/plugins/mistral/modelgauge/suts/mistral_sut.py +++ b/plugins/mistral/modelgauge/suts/mistral_sut.py @@ -164,32 +164,28 @@ def translate_response(self, request: MistralAIRequest, response: MistralAIRespo return SUTResponse(completions=[SUTCompletion(text=str(safest_completion))]) -MISTRAL_API_KEY = InjectSecret(MistralAIAPIKey) - -# TODO: Register both moderated and unmoderated SUTs. -model_name = "ministral-8b-2410" # Mistral's endpoint schema -model_uid = f"mistralai-{model_name}" -SUTS.register(MistralAISut, model_uid, model_name, MISTRAL_API_KEY) - - -model_name = "mistral-large-2411" -model_uid = f"mistralai-{model_name}" -SUTS.register(MistralAISut, model_uid, model_name, MISTRAL_API_KEY) - -moderation_model_name = "mistral-moderation-latest" -num_generations = 2 -temperature = 0.3 -threshold = 0.2 -# TODO: Add more values to UID -model_uid = f"mistralai-{model_name}-moderated" - -SUTS.register( - MistralAIModeratedSut, - model_uid, - model_name, - moderation_model_name, - num_generations, - temperature, - threshold, - MISTRAL_API_KEY, -) +def register_suts_for_model(model_name): + MISTRAL_API_KEY = InjectSecret(MistralAIAPIKey) + # Register standard SUT. + model_uid = f"mistralai-{model_name}" + SUTS.register(MistralAISut, model_uid, model_name, MISTRAL_API_KEY) + # Register moderated SUT. 
+ moderation_model_name = "mistral-moderation-latest" + num_generations = 5 + temperature = 0.3 + threshold = 0.2 + moderated_model_uid = f"mistralai-{model_name}-moderated" + SUTS.register( + MistralAIModeratedSut, + moderated_model_uid, + model_name, + moderation_model_name, + num_generations, + temperature, + threshold, + MISTRAL_API_KEY, + ) + + +register_suts_for_model("ministral-8b-2410") +register_suts_for_model("mistral-large-2411") From 8282f8b08709b11e312e0f7ca28790db0a2706e9 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Tue, 10 Dec 2024 08:14:31 -1000 Subject: [PATCH 6/9] rename method --- plugins/mistral/modelgauge/suts/mistral_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/mistral/modelgauge/suts/mistral_client.py b/plugins/mistral/modelgauge/suts/mistral_client.py index 835221db..ff35c2ac 100644 --- a/plugins/mistral/modelgauge/suts/mistral_client.py +++ b/plugins/mistral/modelgauge/suts/mistral_client.py @@ -50,7 +50,7 @@ def client(self) -> Mistral: return self._client @staticmethod - def _retry_request(endpoint, kwargs: dict): + def _make_request(endpoint, kwargs: dict): try: response = endpoint(**kwargs) return response From 12530f5aa674317ac76a9385e7f7e769fa68b685 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Tue, 10 Dec 2024 08:21:43 -1000 Subject: [PATCH 7/9] oops --- plugins/mistral/modelgauge/suts/mistral_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/mistral/modelgauge/suts/mistral_client.py b/plugins/mistral/modelgauge/suts/mistral_client.py index ff35c2ac..a456e6b8 100644 --- a/plugins/mistral/modelgauge/suts/mistral_client.py +++ b/plugins/mistral/modelgauge/suts/mistral_client.py @@ -68,7 +68,7 @@ def request(self, req: dict): if self.client.chat.sdk_configuration._hooks.before_request_hooks: # work around bug in client self.client.chat.sdk_configuration._hooks.before_request_hooks = [] - return self._retry_request(self.client.chat.complete, req) + return self._make_request(self.client.chat.complete, req) def score_conversation(self, model, prompt, response): """Returns moderation object for a conversation.""" @@ -79,4 +79,4 @@ def score_conversation(self, model, prompt, response): {"role": "assistant", "content": response}, ], } - return self._retry_request(self.client.classifiers.moderate_chat, req) + return self._make_request(self.client.classifiers.moderate_chat, req) From 830e709b48e9ee9d39c7663467d88e333d4d3c77 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Tue, 10 Dec 2024 08:27:12 -1000 Subject: [PATCH 8/9] lint --- plugins/mistral/modelgauge/suts/mistral_sut.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/plugins/mistral/modelgauge/suts/mistral_sut.py b/plugins/mistral/modelgauge/suts/mistral_sut.py index c4b18e9b..cc708134 100644 --- a/plugins/mistral/modelgauge/suts/mistral_sut.py +++ b/plugins/mistral/modelgauge/suts/mistral_sut.py @@ -96,9 +96,9 @@ def __init__( threshold: float, api_key: MistralAIAPIKey, ): - assert num_generations > 1, ( - "The moderation strategy uses a sampling-based mechanism. num_generations should be greater than 1." - ) + assert ( + num_generations > 1 + ), "The moderation strategy uses a sampling-based mechanism. num_generations should be greater than 1." 
super().__init__(uid) self.model_name = model_name self.moderation_model_name = moderation_model_name From d6ddf71404e9d4c48b37af7e611be4c479f4b9e6 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Tue, 10 Dec 2024 08:41:38 -1000 Subject: [PATCH 9/9] fix tests --- plugins/mistral/tests/test_mistral_sut.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/plugins/mistral/tests/test_mistral_sut.py b/plugins/mistral/tests/test_mistral_sut.py index 47fa9714..70d21321 100644 --- a/plugins/mistral/tests/test_mistral_sut.py +++ b/plugins/mistral/tests/test_mistral_sut.py @@ -53,7 +53,7 @@ def response(): @pytest.fixture def sut(): - return MistralAISut("ministral-8b-latest", "ministral-8b-latest", "latest", MistralAIAPIKey("fake")) + return MistralAISut("ministral-8b-latest", "ministral-8b-latest", MistralAIAPIKey("fake")) @pytest.fixture @@ -118,7 +118,7 @@ def moderated_response(): @pytest.fixture def moderated_sut(): return MistralAIModeratedSut( - "ministral-8b-latest", "ministral-8b-latest", "latest", "moderator", 2, 0.3, 0.3, MistralAIAPIKey("fake") + "ministral-8b-latest", "ministral-8b-latest", "moderator", 2, 0.3, 0.3, MistralAIAPIKey("fake") ) @@ -154,7 +154,7 @@ def test_response(self, moderated_sut, moderated_req, moderated_response): def test_response_over_safety_threshold(self, moderated_req, moderated_response): sut = MistralAIModeratedSut( - "ministral-8b-latest", "ministral-8b-latest", "latest", "moderator", 2, 0.3, 0.001, MistralAIAPIKey("fake") + "ministral-8b-latest", "ministral-8b-latest", "moderator", 2, 0.3, 0.001, MistralAIAPIKey("fake") ) resp = sut.translate_response(request=moderated_req, response=moderated_response) assert resp == SUTResponse(completions=[SUTCompletion(text="I'm sorry I cannot assist with this request.")])
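Taken together, the series implements Mistral's system-level guardrails recipe: sample num_generations candidate completions at a fixed temperature, score each (prompt, completion) conversation with the moderation classifier, rank candidates by their worst (maximum) category score, and return the best-ranked candidate, refusing outright when even that candidate reaches the threshold. A minimal standalone sketch of the selection step follows; pick_safest and REFUSAL are illustrative names, not part of modelgauge or the Mistral SDK.

REFUSAL = "I'm sorry I cannot assist with this request."

def pick_safest(candidates: dict[str, dict[str, float]], threshold: float) -> str:
    # Judge each candidate by its worst (maximum) moderation category score.
    scored = [(text, max(scores.values())) for text, scores in candidates.items()]
    scored.sort(key=lambda pair: pair[1])  # safest candidate first
    safest_text, lowest_score = scored[0]
    # If even the safest candidate trips the threshold, refuse.
    return REFUSAL if lowest_score >= threshold else safest_text

candidates = {
    "Response 1": {"hazard_1": 0.05, "hazard_2": 0.1},
    "Response 2": {"hazard_1": 0.15, "hazard_2": 0.2},
}
assert pick_safest(candidates, threshold=0.2) == "Response 1"
assert pick_safest(candidates, threshold=0.01) == REFUSAL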
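Patch 3's subject says the moderation request is wrapped in a retry block, but _retry_request (renamed to _make_request in patch 6) only centralizes the try/except; no call is actually retried. A wrapper that does retry, in the spirit of that subject line, might look like this sketch (standard library only; the attempt count and backoff constants are illustrative):

import time

def retry_request(endpoint, kwargs: dict, max_attempts: int = 3, base_delay: float = 1.0):
    """Call endpoint(**kwargs), retrying with exponential backoff on failure."""
    for attempt in range(1, max_attempts + 1):
        try:
            return endpoint(**kwargs)
        except Exception:
            if attempt == max_attempts:
                raise  # out of attempts; surface the original error
            time.sleep(base_delay * 2 ** (attempt - 1))  # sleep 1s, 2s, 4s, ...

In practice the except clause would name the SDK's transient error types rather than bare Exception, so that non-retryable failures such as HTTPValidationError still fail fast.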