-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add mistral large sut hosted on vertex ai (#629)
* first pass at mistral large sut hosted on vertex ai * Mistral SUT * minstral-8b-instruct is not available on vertex at the moment * not used * add missing google auth dependency * add region as secret; fix model UID to include the version * rename * rename for clarity * noop; lint * refactor to avoid conflicts with the other mistral SUT * add ministral-8b sut * fix version string; switch mistral-large to MistralAI and disable it in the VertexAI sut, with comment explaining why * update mistral plugin name * remove unneeded mistral package * add temperature * noop; add todo for later if we want to use vertex * add retry logic * add max_tokens; fix thresholds for exponential backoff * add max_tokens * added max_tokens and temperature to test * noop; clean debug code * update uid and model name * update uid to keep this model distinct from the same model hosted by MistralAI * update authentication for vertex AI, and add the vendor name to the SUT uid
- Loading branch information
1 parent
c494e58
commit 0c86e60
Showing
14 changed files
with
1,557 additions
and
581 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Plugin for models hosted on MistralAI. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from mistralai import Mistral | ||
|
||
from mistralai.models import HTTPValidationError, SDKError | ||
from mistralai.utils import BackoffStrategy, RetryConfig | ||
|
||
from modelgauge.secret_values import RequiredSecret, SecretDescription | ||
|
||
|
||
# Retry/backoff tuning for the Mistral SDK's RetryConfig. All values are in
# milliseconds, matching what BackoffStrategy expects.
BACKOFF_INITIAL_MILLIS = 500  # delay before the first retry
BACKOFF_MAX_INTERVAL_MILLIS = 10_000  # cap on the delay between any two retries
BACKOFF_EXPONENT = 1.1  # multiplier applied to the delay after each retry
BACKOFF_MAX_ELAPSED_MILLIS = 60_000  # give up retrying after this much total time
|
||
|
||
class MistralAIAPIKey(RequiredSecret):
    """Required secret carrying the MistralAI API key."""

    @classmethod
    def description(cls) -> SecretDescription:
        """Describe where the key is found in the secrets config."""
        secret = SecretDescription(
            scope="mistralai",
            key="api_key",
            instructions="MistralAI API key. See https://docs.mistral.ai/getting-started/quickstart/",
        )
        return secret
|
||
|
||
class MistralAIClient:
    """Thin wrapper around the Mistral SDK client.

    Configures exponential-backoff retries on the underlying client and
    exposes a single `request` entry point for chat completions.
    """

    def __init__(
        self,
        model_name: str,
        api_key: MistralAIAPIKey,
    ):
        self.model_name = model_name
        self.api_key = api_key.value
        # Created lazily by the `client` property so no network/SDK setup
        # happens at construction time.
        self._client = None

    @property
    def client(self) -> Mistral:
        """Lazily build and cache the SDK client with retry configured."""
        if not self._client:
            self._client = Mistral(
                api_key=self.api_key,
                retry_config=RetryConfig(
                    "backoff",
                    BackoffStrategy(
                        BACKOFF_INITIAL_MILLIS,
                        BACKOFF_MAX_INTERVAL_MILLIS,
                        BACKOFF_EXPONENT,
                        # Bug fix: this positional argument is the maximum total
                        # elapsed retry time. It was previously passed
                        # BACKOFF_MAX_INTERVAL_MILLIS (10s), leaving
                        # BACKOFF_MAX_ELAPSED_MILLIS unused and cutting the
                        # intended 60s retry window down to 10s.
                        BACKOFF_MAX_ELAPSED_MILLIS,
                    ),
                    True,  # retry on connection errors
                ),
            )
        return self._client

    def request(self, req: dict):
        """Send a chat-completion request; `req` is expanded as keyword args.

        Exceptions are allowed to propagate to the caller unchanged.
        TODO: verify whether HTTPValidationError actually occurs in practice,
        and whether the SDK's retry strategy already absorbs SDKError; if
        distinct handling is ever needed, catch HTTPValidationError / SDKError
        here explicitly.
        """
        return self.client.chat.complete(**req)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
from typing import Optional | ||
|
||
from mistralai.models import ChatCompletionResponse | ||
from modelgauge.prompt import TextPrompt | ||
from modelgauge.secret_values import InjectSecret | ||
from modelgauge.sut import PromptResponseSUT, SUTCompletion, SUTResponse | ||
from modelgauge.sut_capabilities import AcceptsTextPrompt | ||
from modelgauge.sut_decorator import modelgauge_sut | ||
from modelgauge.sut_registry import SUTS | ||
from modelgauge.suts.mistral_client import MistralAIAPIKey, MistralAIClient | ||
|
||
from pydantic import BaseModel | ||
|
||
# Chat role used for all outgoing prompt messages.
_USER_ROLE = "user"
|
||
|
||
class MistralAIRequest(BaseModel):
    """Request payload for Mistral's chat-completion endpoint.

    Bug fix: `max_tokens` previously had no default, which makes it a
    *required* field in pydantic even though it is annotated Optional.
    `translate_text_prompt` omits it when the prompt sets no max_tokens,
    which would raise a ValidationError. Defaulting to None makes the field
    genuinely optional and is backward-compatible for all callers.
    """

    model: str
    messages: list[dict]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
|
||
|
||
class MistralAIResponse(ChatCompletionResponse):
    """The ChatCompletionResponse class from Mistral matches our Response
    objects now, but we subclass it for consistency and so we can adjust it
    in case the upstream object changes."""

    pass
|
||
|
||
@modelgauge_sut(capabilities=[AcceptsTextPrompt])
class MistralAISut(PromptResponseSUT):
    """A MistralAI SUT hosted on MistralAI."""

    def __init__(
        self,
        uid: str,
        model_name: str,
        model_version: str,
        api_key: MistralAIAPIKey,
    ):
        super().__init__(uid)
        self.model_name = model_name
        self.model_version = model_version
        self._api_key = api_key
        # Client is constructed on first use; see the `client` property.
        self._client = None

    @property
    def client(self):
        """Return the cached MistralAIClient, creating it on first access."""
        if not self._client:
            self._client = MistralAIClient(self.model_name, self._api_key)
        return self._client

    def translate_text_prompt(self, prompt: TextPrompt) -> MistralAIRequest:
        """Build the request payload, forwarding only the options that are set."""
        request_args: dict = {
            "model": self.model_name,
            "messages": [{"role": _USER_ROLE, "content": prompt.text}],
        }
        temperature = prompt.options.temperature
        max_tokens = prompt.options.max_tokens
        if temperature is not None:
            request_args["temperature"] = temperature
        if max_tokens is not None:
            request_args["max_tokens"] = max_tokens
        return MistralAIRequest(**request_args)

    def evaluate(self, request: MistralAIRequest) -> ChatCompletionResponse:
        """Send the request to Mistral, dropping unset (None) fields."""
        return self.client.request(request.model_dump(exclude_none=True))  # type: ignore

    def translate_response(self, request: MistralAIRequest, response: MistralAIResponse) -> SUTResponse:
        """Turn each returned choice into a SUTCompletion."""
        texts = []
        for choice in response.choices:
            content = choice.message.content
            assert content is not None
            texts.append(str(content))
        return SUTResponse(completions=[SUTCompletion(text=t) for t in texts])
|
||
|
||
MISTRAL_API_KEY = InjectSecret(MistralAIAPIKey)

# Register each hosted model once. The UID embeds vendor, model, and version
# so SUTs stay distinct from the same models hosted elsewhere.
for model_name, model_version in (
    ("ministral-8b", "2410"),
    ("mistral-large", "2411"),
):
    model_full_name = f"{model_name}-latest"  # Mistral's endpoint schema
    model_uid = f"mistralai-{model_name}-{model_version}"
    SUTS.register(MistralAISut, model_uid, model_full_name, model_version, MISTRAL_API_KEY)
Oops, something went wrong.