Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mistral large sut hosted on vertex ai #629

Merged
merged 24 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
e805656
first pass at mistral large sut hosted on vertex ai
rogthefrog Oct 22, 2024
e1f9e01
Mistral SUT
rogthefrog Nov 19, 2024
c9749ce
minstral-8b-instruct is not available on vertex at the moment
rogthefrog Nov 19, 2024
ee564fd
not used
rogthefrog Nov 19, 2024
8b2183c
add missing google auth dependency
rogthefrog Nov 19, 2024
132d917
add region as secret; fix model UID to include the version
rogthefrog Nov 19, 2024
9496cd0
rename
rogthefrog Nov 19, 2024
59294ee
rename for clarity
rogthefrog Nov 19, 2024
5e9e7d8
noop; lint
rogthefrog Nov 19, 2024
d001bce
refactor to avoid conflicts with the other mistral SUT
rogthefrog Nov 19, 2024
e10a470
add ministral-8b sut
rogthefrog Nov 19, 2024
c450624
fix version string; switch mistral-large to MistralAI and disable it …
rogthefrog Nov 19, 2024
b739584
update mistral plugin name
rogthefrog Nov 19, 2024
5c8ae3e
remove unneeded mistral package
rogthefrog Nov 20, 2024
cd190fa
add temperature
rogthefrog Nov 20, 2024
d099fde
noop; add todo for later if we want to use vertex
rogthefrog Nov 20, 2024
8996ea1
add retry logic
rogthefrog Nov 20, 2024
6f5e745
add max_tokens; fix thresholds for exponential backoff
rogthefrog Nov 20, 2024
7288ed2
add max_tokens
rogthefrog Nov 20, 2024
68115cf
added max_tokens and temperature to test
rogthefrog Nov 20, 2024
fe5f136
noop; clean debug code
rogthefrog Nov 20, 2024
a055197
update uid and model name
rogthefrog Nov 21, 2024
b3ed7f3
update uid to keep this model distinct from the same model hosted by …
rogthefrog Nov 21, 2024
74997e8
update authentication for vertex AI, and add the vendor name to the S…
rogthefrog Nov 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions plugins/mistral/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Plugin for models hosted on MistralAI.
66 changes: 66 additions & 0 deletions plugins/mistral/modelgauge/suts/mistral_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from mistralai import Mistral

from mistralai.models import HTTPValidationError, SDKError
from mistralai.utils import BackoffStrategy, RetryConfig

from modelgauge.secret_values import RequiredSecret, SecretDescription


BACKOFF_INITIAL_MILLIS = 500
BACKOFF_MAX_INTERVAL_MILLIS = 10_000
BACKOFF_EXPONENT = 1.1
BACKOFF_MAX_ELAPSED_MILLIS = 60_000


class MistralAIAPIKey(RequiredSecret):
@classmethod
def description(cls) -> SecretDescription:
return SecretDescription(
scope="mistralai",
key="api_key",
instructions="MistralAI API key. See https://docs.mistral.ai/getting-started/quickstart/",
)


class MistralAIClient:
def __init__(
self,
model_name: str,
api_key: MistralAIAPIKey,
):
self.model_name = model_name
self.api_key = api_key.value
self._client = None

@property
def client(self) -> Mistral:
if not self._client:
self._client = Mistral(
api_key=self.api_key,
retry_config=RetryConfig(
"backoff",
BackoffStrategy(
BACKOFF_INITIAL_MILLIS,
BACKOFF_MAX_INTERVAL_MILLIS,
BACKOFF_EXPONENT,
BACKOFF_MAX_INTERVAL_MILLIS,
),
True,
),
)
return self._client

def request(self, req: dict):
response = None
try:
response = self.client.chat.complete(**req)
return response
# TODO check if this actually happens
except HTTPValidationError as exc:
raise (exc)
# TODO check if the retry strategy takes care of this
except SDKError as exc:
raise (exc)
# TODO what else can happen?
except Exception as exc:
raise (exc)
89 changes: 89 additions & 0 deletions plugins/mistral/modelgauge/suts/mistral_sut.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from typing import Optional

from mistralai.models import ChatCompletionResponse
from modelgauge.prompt import TextPrompt
from modelgauge.secret_values import InjectSecret
from modelgauge.sut import PromptResponseSUT, SUTCompletion, SUTResponse
from modelgauge.sut_capabilities import AcceptsTextPrompt
from modelgauge.sut_decorator import modelgauge_sut
from modelgauge.sut_registry import SUTS
from modelgauge.suts.mistral_client import MistralAIAPIKey, MistralAIClient

from pydantic import BaseModel

_USER_ROLE = "user"


class MistralAIRequest(BaseModel):
model: str
messages: list[dict]
temperature: Optional[float] = None
max_tokens: Optional[int]


class MistralAIResponse(ChatCompletionResponse):
"""The ChatCompletionResponse class from Mistral matches our Response
objects now, but we subclass it for consistency and so we can adjust it
in case the upstream object changes."""

pass


@modelgauge_sut(capabilities=[AcceptsTextPrompt])
class MistralAISut(PromptResponseSUT):
"""A MistralAI SUT hosted on MistralAI."""

def __init__(
self,
uid: str,
model_name: str,
model_version: str,
api_key: MistralAIAPIKey,
):
super().__init__(uid)
self.model_name = model_name
self.model_version = model_version
self._api_key = api_key
self._client = None

@property
def client(self):
if not self._client:
self._client = MistralAIClient(self.model_name, self._api_key)
return self._client

def translate_text_prompt(self, prompt: TextPrompt) -> MistralAIRequest:
args = {"model": self.model_name, "messages": [{"role": _USER_ROLE, "content": prompt.text}]}
if prompt.options.temperature is not None:
args["temperature"] = prompt.options.temperature
if prompt.options.max_tokens is not None:
args["max_tokens"] = prompt.options.max_tokens
return MistralAIRequest(**args)

def evaluate(self, request: MistralAIRequest) -> ChatCompletionResponse:
response = self.client.request(request.model_dump(exclude_none=True)) # type: ignore
return response

def translate_response(self, request: MistralAIRequest, response: MistralAIResponse) -> SUTResponse:
completions = []
for choice in response.choices:
text = choice.message.content
assert text is not None
completions.append(SUTCompletion(text=str(text)))
return SUTResponse(completions=completions)


MISTRAL_API_KEY = InjectSecret(MistralAIAPIKey)

model_name = "ministral-8b"
model_version = "2410"
model_full_name = "ministral-8b-latest" # Mistral's endpoint schema
model_uid = f"mistralai-{model_name}-{model_version}"
SUTS.register(MistralAISut, model_uid, model_full_name, model_version, MISTRAL_API_KEY)


model_name = "mistral-large"
model_version = "2411"
model_full_name = "mistral-large-latest" # Mistral's endpoint schema
model_uid = f"mistralai-{model_name}-{model_version}"
SUTS.register(MistralAISut, model_uid, model_full_name, model_version, MISTRAL_API_KEY)
Loading