Skip to content

Commit

Permalink
Add mistral large sut hosted on vertex ai (#629)
Browse files Browse the repository at this point in the history
* first pass at mistral large sut hosted on vertex ai

* Mistral SUT

* ministral-8b-instruct is not available on vertex at the moment

* not used

* add missing google auth dependency

* add region as secret; fix model UID to include the version

* rename

* rename for clarity

* noop; lint

* refactor to avoid conflicts with the other mistral SUT

* add ministral-8b sut

* fix version string; switch mistral-large to MistralAI and disable it in the VertexAI sut, with comment explaining why

* update mistral plugin name

* remove unneeded mistral package

* add temperature

* noop; add todo for later if we want to use vertex

* add retry logic

* add max_tokens; fix thresholds for exponential backoff

* add max_tokens

* added max_tokens and temperature to test

* noop; clean debug code

* update uid and model name

* update uid to keep this model distinct from the same model hosted by MistralAI

* update authentication for vertex AI, and add the vendor name to the SUT uid
  • Loading branch information
rogthefrog authored Nov 21, 2024
1 parent c494e58 commit 0c86e60
Show file tree
Hide file tree
Showing 14 changed files with 1,557 additions and 581 deletions.
1 change: 1 addition & 0 deletions plugins/mistral/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Plugin for models hosted on MistralAI.
66 changes: 66 additions & 0 deletions plugins/mistral/modelgauge/suts/mistral_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from mistralai import Mistral

from mistralai.models import HTTPValidationError, SDKError
from mistralai.utils import BackoffStrategy, RetryConfig

from modelgauge.secret_values import RequiredSecret, SecretDescription


# Tuning values for the Mistral SDK's exponential-backoff retry strategy.
# All durations are in milliseconds; the exponent is the per-retry multiplier.
BACKOFF_INITIAL_MILLIS = 500  # delay before the first retry
BACKOFF_MAX_INTERVAL_MILLIS = 10_000  # cap on any single retry delay
BACKOFF_EXPONENT = 1.1
BACKOFF_MAX_ELAPSED_MILLIS = 60_000  # presumably the total retrying time budget — TODO confirm against BackoffStrategy docs


class MistralAIAPIKey(RequiredSecret):
    """Required secret holding a MistralAI API key (scope "mistralai", key "api_key")."""

    @classmethod
    def description(cls) -> SecretDescription:
        """Describe where this secret lives and how to obtain one."""
        return SecretDescription(
            scope="mistralai",
            key="api_key",
            instructions="MistralAI API key. See https://docs.mistral.ai/getting-started/quickstart/",
        )


class MistralAIClient:
    """Thin wrapper around the MistralAI SDK client.

    Holds the model name and API key, and lazily builds a `Mistral` client
    configured with exponential backoff on retries.
    """

    def __init__(
        self,
        model_name: str,
        api_key: MistralAIAPIKey,
    ):
        self.model_name = model_name
        self.api_key = api_key.value
        self._client = None

    @property
    def client(self) -> Mistral:
        # Built lazily so merely importing/instantiating this wrapper
        # doesn't require a live connection.
        if not self._client:
            self._client = Mistral(
                api_key=self.api_key,
                retry_config=RetryConfig(
                    "backoff",
                    BackoffStrategy(
                        BACKOFF_INITIAL_MILLIS,
                        BACKOFF_MAX_INTERVAL_MILLIS,
                        BACKOFF_EXPONENT,
                        # Bug fix: this argument was BACKOFF_MAX_INTERVAL_MILLIS
                        # (passed twice), which cut the total retry budget to
                        # 10s; the fourth BackoffStrategy argument is the
                        # max-elapsed threshold, and BACKOFF_MAX_ELAPSED_MILLIS
                        # was defined but never used.
                        BACKOFF_MAX_ELAPSED_MILLIS,
                    ),
                    True,
                ),
            )
        return self._client

    def request(self, req: dict):
        """Send a chat-completion request.

        `req` is expanded into keyword arguments for `Mistral.chat.complete`
        (model, messages, temperature, max_tokens, ...). Returns the SDK's
        ChatCompletionResponse.
        """
        try:
            return self.client.chat.complete(**req)
        # TODO check if this actually happens
        except HTTPValidationError:
            raise
        # TODO check if the retry strategy takes care of this
        except SDKError:
            raise
        # TODO what else can happen?
        except Exception:
            raise
89 changes: 89 additions & 0 deletions plugins/mistral/modelgauge/suts/mistral_sut.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from typing import Optional

from mistralai.models import ChatCompletionResponse
from modelgauge.prompt import TextPrompt
from modelgauge.secret_values import InjectSecret
from modelgauge.sut import PromptResponseSUT, SUTCompletion, SUTResponse
from modelgauge.sut_capabilities import AcceptsTextPrompt
from modelgauge.sut_decorator import modelgauge_sut
from modelgauge.sut_registry import SUTS
from modelgauge.suts.mistral_client import MistralAIAPIKey, MistralAIClient

from pydantic import BaseModel

_USER_ROLE = "user"


class MistralAIRequest(BaseModel):
    """Request payload for Mistral's chat-completion endpoint.

    Optional fields default to None and are stripped before sending
    (the caller serializes with `model_dump(exclude_none=True)`).
    """

    model: str
    messages: list[dict]
    temperature: Optional[float] = None
    # Bug fix: without "= None" pydantic treats max_tokens as a required
    # field, so building a request from a prompt that sets no max_tokens
    # would raise a ValidationError.
    max_tokens: Optional[int] = None


class MistralAIResponse(ChatCompletionResponse):
    """The ChatCompletionResponse class from Mistral matches our Response
    objects now, but we subclass it for consistency and so we can adjust it
    in case the upstream object changes."""

    pass


@modelgauge_sut(capabilities=[AcceptsTextPrompt])
class MistralAISut(PromptResponseSUT):
    """A MistralAI SUT hosted on MistralAI."""

    def __init__(
        self,
        uid: str,
        model_name: str,
        model_version: str,
        api_key: MistralAIAPIKey,
    ):
        super().__init__(uid)
        self.model_name = model_name
        self.model_version = model_version
        self._api_key = api_key
        self._client = None

    @property
    def client(self):
        """Lazily build the underlying MistralAIClient on first access."""
        if self._client is None:
            self._client = MistralAIClient(self.model_name, self._api_key)
        return self._client

    def translate_text_prompt(self, prompt: TextPrompt) -> MistralAIRequest:
        """Turn a TextPrompt into a single-user-message chat request.

        temperature and max_tokens are only included when the prompt's
        options set them.
        """
        request_args: dict = {
            "model": self.model_name,
            "messages": [{"role": _USER_ROLE, "content": prompt.text}],
        }
        options = prompt.options
        if options.temperature is not None:
            request_args["temperature"] = options.temperature
        if options.max_tokens is not None:
            request_args["max_tokens"] = options.max_tokens
        return MistralAIRequest(**request_args)

    def evaluate(self, request: MistralAIRequest) -> ChatCompletionResponse:
        """Send the request to MistralAI, dropping unset optional fields."""
        return self.client.request(request.model_dump(exclude_none=True))  # type: ignore

    def translate_response(self, request: MistralAIRequest, response: MistralAIResponse) -> SUTResponse:
        """Extract one SUTCompletion per choice in the model's response."""
        completions = []
        for choice in response.choices:
            message_text = choice.message.content
            assert message_text is not None
            completions.append(SUTCompletion(text=str(message_text)))
        return SUTResponse(completions=completions)


MISTRAL_API_KEY = InjectSecret(MistralAIAPIKey)

# (short name, version, full endpoint name per Mistral's schema)
_REGISTERED_MODELS = [
    ("ministral-8b", "2410", "ministral-8b-latest"),
    ("mistral-large", "2411", "mistral-large-latest"),
]

for model_name, model_version, model_full_name in _REGISTERED_MODELS:
    # The UID carries the vendor prefix and version to keep these models
    # distinct from the same models hosted elsewhere (e.g. on Vertex AI).
    model_uid = f"mistralai-{model_name}-{model_version}"
    SUTS.register(MistralAISut, model_uid, model_full_name, model_version, MISTRAL_API_KEY)
Loading

0 comments on commit 0c86e60

Please sign in to comment.