Mistral SUTs with moderation #740

Merged · 9 commits · Dec 11, 2024
26 changes: 20 additions & 6 deletions plugins/mistral/modelgauge/suts/mistral_client.py
@@ -49,13 +49,10 @@ def client(self) -> Mistral:
             )
         return self._client
 
-    def request(self, req: dict):
-        response = None
-        if self.client.chat.sdk_configuration._hooks.before_request_hooks:
-            # work around bug in client
-            self.client.chat.sdk_configuration._hooks.before_request_hooks = []
+    @staticmethod
+    def _make_request(endpoint, kwargs: dict):
         try:
-            response = self.client.chat.complete(**req)
+            response = endpoint(**kwargs)
             return response
         # TODO check if this actually happens
         except HTTPValidationError as exc:
@@ -66,3 +63,20 @@ def request(self, req: dict):
         # TODO what else can happen?
         except Exception as exc:
             raise (exc)
+
+    def request(self, req: dict):
+        if self.client.chat.sdk_configuration._hooks.before_request_hooks:
+            # work around bug in client
+            self.client.chat.sdk_configuration._hooks.before_request_hooks = []
+        return self._make_request(self.client.chat.complete, req)
+
+    def score_conversation(self, model, prompt, response):
+        """Returns moderation object for a conversation."""
+        req = {
+            "model": model,
+            "inputs": [
+                {"role": "user", "content": prompt},
+                {"role": "assistant", "content": response},
+            ],
+        }
+        return self._make_request(self.client.classifiers.moderate_chat, req)
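Note for readers unfamiliar with the moderation endpoint: `score_conversation` above routes a two-turn conversation through Mistral's chat-moderation classifier via `_make_request`. A minimal standalone sketch of the same call, assuming the v1 `mistralai` SDK and a valid API key (the key, messages, and printed scores here are illustrative):

```python
# Sketch of the call that score_conversation wraps; values are illustrative.
from mistralai import Mistral

client = Mistral(api_key="YOUR_MISTRAL_API_KEY")  # placeholder secret
moderation = client.classifiers.moderate_chat(
    model="mistral-moderation-latest",
    inputs=[
        {"role": "user", "content": "Why did the chicken cross the road?"},
        {"role": "assistant", "content": "To get to the other side."},
    ],
)
# One ClassificationObject per input conversation; category_scores maps
# hazard categories to scores in [0, 1].
print(moderation.results[0].category_scores)
```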
130 changes: 116 additions & 14 deletions plugins/mistral/modelgauge/suts/mistral_sut.py
@@ -1,6 +1,7 @@
+import warnings
 from typing import Optional
 
-from mistralai.models import ChatCompletionResponse
+from mistralai.models import ChatCompletionResponse, ClassificationResponse
 from modelgauge.prompt import TextPrompt
 from modelgauge.secret_values import InjectSecret
 from modelgauge.sut import PromptResponseSUT, SUTCompletion, SUTResponse
@@ -19,6 +20,7 @@ class MistralAIRequest(BaseModel):
     messages: list[dict]
     temperature: Optional[float] = None
     max_tokens: Optional[int]
+    n: int = 1  # Number of completions to generate.
 
 
 class MistralAIResponse(ChatCompletionResponse):
@@ -37,12 +39,10 @@ def __init__(
         self,
         uid: str,
         model_name: str,
-        model_version: str,
         api_key: MistralAIAPIKey,
     ):
         super().__init__(uid)
         self.model_name = model_name
-        self.model_version = model_version
         self._api_key = api_key
         self._client = None
 
@@ -73,17 +73,119 @@ def translate_response(self, request: MistralAIRequest, response: MistralAIResponse) -> SUTResponse:
         return SUTResponse(completions=completions)
 
 
-MISTRAL_API_KEY = InjectSecret(MistralAIAPIKey)
+class MistralAIResponseWithModerations(BaseModel):
+    """Mistral's ChatCompletionResponse object + moderation scores."""
 
-model_name = "ministral-8b"
-model_version = "2410"
-model_full_name = "ministral-8b-latest"  # Mistral's endpoint schema
-model_uid = f"mistralai-{model_name}-{model_version}"
-SUTS.register(MistralAISut, model_uid, model_full_name, model_version, MISTRAL_API_KEY)
+    response: ChatCompletionResponse  # Contains multiple completions.
+    moderations: dict[int, ClassificationResponse]  # Keys correspond to a choice's index field.
 
 
-model_name = "mistral-large"
-model_version = "2411"
-model_full_name = "mistral-large-latest"  # Mistral's endpoint schema
-model_uid = f"mistralai-{model_name}-{model_version}"
-SUTS.register(MistralAISut, model_uid, model_full_name, model_version, MISTRAL_API_KEY)
+@modelgauge_sut(capabilities=[AcceptsTextPrompt])
+class MistralAIModeratedSut(PromptResponseSUT):
+    """A MistralAI SUT hosted on MistralAI with safety moderation based on the following recipe:
+    https://colab.research.google.com/github/mistralai/cookbook/blob/main/mistral/moderation/system-level-guardrails.ipynb#scrollTo=OlnXFkLo5sKX
+    """
+
+    def __init__(
+        self,
+        uid: str,
+        model_name: str,
+        moderation_model_name: str,
+        num_generations: int,
+        temperature: float,
+        threshold: float,
+        api_key: MistralAIAPIKey,
+    ):
+        assert (
+            num_generations > 1
+        ), "The moderation strategy uses a sampling-based mechanism. num_generations should be greater than 1."
+        super().__init__(uid)
+        self.model_name = model_name
+        self.moderation_model_name = moderation_model_name
+        self.num_generations = num_generations
+        self.temperature = temperature
+        self.threshold = threshold
+        self._api_key = api_key
+        self._client = None
+
+    @property
+    def client(self):
+        if not self._client:
+            self._client = MistralAIClient(self.model_name, self._api_key)
+        return self._client
+
+    def translate_text_prompt(self, prompt: TextPrompt) -> MistralAIRequest:
+        # Warn if prompt options are overridden.
+        if prompt.options.temperature is not None and prompt.options.temperature != self.temperature:
+            warnings.warn(
+                f"This SUT overrides the prompt's temperature value of {prompt.options.temperature} to {self.temperature}."
+            )
+        if prompt.options.num_completions is not None and prompt.options.num_completions != self.num_generations:
+            warnings.warn(
+                f"This SUT overrides the prompt's num_completions value of {prompt.options.num_completions} to {self.num_generations}."
+            )
+
+        args = {"model": self.model_name, "messages": [{"role": _USER_ROLE, "content": prompt.text}]}
+        if prompt.options.max_tokens is not None:
+            args["max_tokens"] = prompt.options.max_tokens
+        return MistralAIRequest(temperature=self.temperature, n=self.num_generations, **args)
+
+    def evaluate(self, request: MistralAIRequest) -> MistralAIResponseWithModerations:
+        response = self.client.request(request.model_dump(exclude_none=True))  # type: ignore
+        assert (
+            len(response.choices) == self.num_generations
+        ), f"Expected {self.num_generations} completions, got {len(response.choices)}."
+        moderations = {}
+        for choice in response.choices:
+            scores = self.client.score_conversation(
+                self.moderation_model_name, request.messages[0]["content"], choice.message.content
+            )
+            moderations[choice.index] = scores
+        return MistralAIResponseWithModerations(response=response, moderations=moderations)
+
+    def translate_response(self, request: MistralAIRequest, response: MistralAIResponseWithModerations) -> SUTResponse:
+        # Sort list of completions by moderation score.
+        scores_list = []
+        for choice in response.response.choices:
+            text = choice.message.content
+            assert text is not None
+            scores = response.moderations[choice.index]
+            assert len(scores.results) == 1, f"Expected 1 moderation object, got {len(scores.results)}."
+            category_scores = scores.results[0].category_scores
+            max_score = max(category_scores.values())
+            scores_list.append((text, max_score))
+
+        sorted_responses = sorted(scores_list, key=lambda x: x[1])
+        lowest_score = sorted_responses[0][1]
+        if lowest_score >= self.threshold:
+            safest_completion = "I'm sorry I cannot assist with this request."
+        else:
+            safest_completion = str(sorted_responses[0][0])
+        return SUTResponse(completions=[SUTCompletion(text=str(safest_completion))])
+
+
+def register_suts_for_model(model_name):
+    MISTRAL_API_KEY = InjectSecret(MistralAIAPIKey)
+    # Register standard SUT.
+    model_uid = f"mistralai-{model_name}"
+    SUTS.register(MistralAISut, model_uid, model_name, MISTRAL_API_KEY)
+    # Register moderated SUT.
+    moderation_model_name = "mistral-moderation-latest"
+    num_generations = 5
+    temperature = 0.3
+    threshold = 0.2
+    moderated_model_uid = f"mistralai-{model_name}-moderated"
+    SUTS.register(
+        MistralAIModeratedSut,
+        moderated_model_uid,
+        model_name,
+        moderation_model_name,
+        num_generations,
+        temperature,
+        threshold,
+        MISTRAL_API_KEY,
+    )
+
+
+register_suts_for_model("ministral-8b-2410")
+register_suts_for_model("mistral-large-2411")
109 changes: 106 additions & 3 deletions plugins/mistral/tests/test_mistral_sut.py
@@ -1,9 +1,20 @@
 import pytest
-from mistralai.models import AssistantMessage, ChatCompletionChoice, UsageInfo
+from mistralai.models import (
+    AssistantMessage,
+    ChatCompletionChoice,
+    ClassificationObject,
+    ClassificationResponse,
+    UsageInfo,
+)
 from modelgauge.prompt import SUTOptions, TextPrompt
 from modelgauge.sut import SUTCompletion, SUTResponse
 from modelgauge.suts.mistral_client import MistralAIAPIKey
-from modelgauge.suts.mistral_sut import MistralAIResponse, MistralAISut
+from modelgauge.suts.mistral_sut import (
+    MistralAIResponse,
+    MistralAIResponseWithModerations,
+    MistralAIModeratedSut,
+    MistralAISut,
+)
 
 
 @pytest.fixture
@@ -13,6 +24,7 @@ def req():
         "messages": [{"role": "user", "content": "Why did the chicken cross the road?"}],
         "max_tokens": 91,
         "temperature": 0.3,
+        "n": 1,
     }
 
 
@@ -41,7 +53,73 @@ def response():
 
 @pytest.fixture
 def sut():
-    return MistralAISut("ministral-8b-latest", "ministral-8b-latest", "latest", MistralAIAPIKey("fake"))
+    return MistralAISut("ministral-8b-latest", "ministral-8b-latest", MistralAIAPIKey("fake"))
+
+
+@pytest.fixture
+def moderated_req():
+    return {
+        "model": "ministral-8b-latest",
+        "messages": [{"role": "user", "content": "Why did the chicken cross the road?"}],
+        "max_tokens": 91,
+        "temperature": 0.3,
+        "n": 2,
+    }
+
+
+def _make_moderation(score):
+    return ClassificationResponse(
+        id="fake-id",
+        model="moderator",
+        results=[
+            ClassificationObject(
+                categories={"hazard_1": False, "hazard_2": False}, category_scores={"hazard_1": 0.0, "hazard_2": score}
+            )
+        ],
+    )
+
+
+@pytest.fixture
+def moderated_response():
+    response = MistralAIResponse(
+        id="ed6c8eccd53e4b319a7bc566f6a53357",
+        object="chat.completion",
+        model="ministral-8b-latest",
+        created=1731977771,
+        usage=UsageInfo(prompt_tokens=11, completion_tokens=22, total_tokens=33),
+        choices=[
+            ChatCompletionChoice(
+                index=0,
+                message=AssistantMessage(
+                    content="Response 1",
+                    tool_calls=None,
+                    prefix=False,
+                    role="assistant",
+                ),
+                finish_reason="stop",
+            ),
+            ChatCompletionChoice(
+                index=1,
+                message=AssistantMessage(
+                    content="Response 2",
+                    tool_calls=None,
+                    prefix=False,
+                    role="assistant",
+                ),
+                finish_reason="stop",
+            ),
+        ],
+    )
+    return MistralAIResponseWithModerations(
+        response=response, moderations={0: _make_moderation(0.1), 1: _make_moderation(0.2)}
+    )
+
+
+@pytest.fixture
+def moderated_sut():
+    return MistralAIModeratedSut(
+        "ministral-8b-latest", "ministral-8b-latest", "moderator", 2, 0.3, 0.3, MistralAIAPIKey("fake")
+    )
 
 
 class TestMistralAISut:
@@ -55,3 +133,27 @@ def test_request(self, sut, req):
     def test_response(self, sut, req, response):
         resp = sut.translate_response(request=req, response=response)
         assert resp == SUTResponse(completions=[SUTCompletion(text="The classic joke has several variations")])
+
+
+class TestMistralAIModeratedSut:
+
+    @pytest.mark.parametrize("prompt_temp,prompt_num_completions", [(None, None), (0.3, 3), (0.1, 1000)])
+    def test_request(self, moderated_sut, moderated_req, prompt_temp, prompt_num_completions):
+        translated_req = moderated_sut.translate_text_prompt(
+            TextPrompt(
+                text="Why did the chicken cross the road?",
+                options=SUTOptions(temperature=prompt_temp, max_tokens=91, num_completions=prompt_num_completions),
+            )
+        )
+        assert translated_req.model_dump(exclude_none=True) == moderated_req
+
+    def test_response(self, moderated_sut, moderated_req, moderated_response):
+        resp = moderated_sut.translate_response(request=moderated_req, response=moderated_response)
+        assert resp == SUTResponse(completions=[SUTCompletion(text="Response 1")])
+
+    def test_response_over_safety_threshold(self, moderated_req, moderated_response):
+        sut = MistralAIModeratedSut(
+            "ministral-8b-latest", "ministral-8b-latest", "moderator", 2, 0.3, 0.001, MistralAIAPIKey("fake")
+        )
+        resp = sut.translate_response(request=moderated_req, response=moderated_response)
+        assert resp == SUTResponse(completions=[SUTCompletion(text="I'm sorry I cannot assist with this request.")])