Mistral Client Migration (#707)
Co-authored-by: b.nativi <[email protected]>
Co-authored-by: Charles Zaloom <[email protected]>
3 people authored Aug 16, 2024
1 parent de16de0 commit 60622a3
Showing 9 changed files with 87 additions and 129 deletions.
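
At a glance, the diffs below make one coordinated change: the mistralai dependency is bumped from <=0.4.2 to >=1.0, and every call site is updated for the 1.x SDK. MistralClient (from mistralai.client) becomes Mistral (from mistralai.sdk), client.chat(...) becomes client.chat.complete(...), ChatMessage objects give way to plain OpenAI-style role/content dicts, MistralException becomes SDKError (from mistralai.models.sdkerror), and ChatCompletionResponseChoice with FinishReason enums becomes ChatCompletionChoice with plain string finish reasons.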
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -29,7 +29,7 @@ repos:
         args: [--line-length=79]
 
   - repo: https://github.com/RobertCraigie/pyright-python
-    rev: v1.1.350
+    rev: v1.1.376
     hooks:
       - id: pyright
         additional_dependencies:
@@ -52,7 +52,7 @@ repos:
             "psycopg2-binary",
             "pgvector",
             "openai",
-            "mistralai<=0.4.2",
+            "mistralai>=1.0",
             "absl-py",
             "nltk",
             "rouge_score",
2 changes: 1 addition & 1 deletion api/pyproject.toml
@@ -18,7 +18,7 @@ dependencies = [
     "structlog",
     "pgvector",
     "openai",
-    "mistralai <= 0.4.2",
+    "mistralai >= 1.0",
     "absl-py",
     "nltk",
     "rouge_score",
62 changes: 24 additions & 38 deletions api/tests/functional-tests/backend/core/test_llm_clients.py
@@ -3,14 +3,13 @@
 from unittest.mock import MagicMock
 
 import pytest
-from mistralai.exceptions import MistralException
-from mistralai.models.chat_completion import (
+from mistralai.models import (
+    AssistantMessage,
+    ChatCompletionChoice,
     ChatCompletionResponse,
-    ChatCompletionResponseChoice,
-    ChatMessage,
-    FinishReason,
+    UsageInfo,
 )
-from mistralai.models.common import UsageInfo
+from mistralai.models.sdkerror import SDKError as MistralSDKError
 from openai import OpenAIError
 from openai.types.chat import ChatCompletionMessage
 from openai.types.chat.chat_completion import ChatCompletion, Choice
@@ -1023,11 +1022,6 @@ def _create_mock_chat_completion_none_content(
     # Check that the WrappedOpenAIClient does not alter the messages.
     assert fake_message == client._process_messages(fake_message)
 
-    # OpenAI only allows the roles of system, user and assistant.
-    invalid_message = [{"role": "invalid", "content": "Some content."}]
-    with pytest.raises(ValueError):
-        client._process_messages(invalid_message)
-
     # The OpenAI Client should be able to connect if the API key is set as the environment variable.
     os.environ["OPENAI_API_KEY"] = "dummy_key"
     client = WrappedOpenAIClient(model_name="model_name")
Expand Down Expand Up @@ -1080,15 +1074,15 @@ def _create_mock_chat_completion_with_bad_length(
model="gpt-3.5-turbo",
object="chat.completion",
choices=[
ChatCompletionResponseChoice(
finish_reason=FinishReason("length"),
ChatCompletionChoice(
finish_reason="length",
index=0,
message=ChatMessage(
role="role",
content="some content",
name=None,
message=AssistantMessage(
role="assistant",
content="some response",
name=None, # type: ignore - mistralai issue
tool_calls=None,
tool_call_id=None,
tool_call_id=None, # type: ignore - mistralai issue
),
)
],
@@ -1106,15 +1100,15 @@ def _create_mock_chat_completion(
         model="gpt-3.5-turbo",
         object="chat.completion",
         choices=[
-            ChatCompletionResponseChoice(
-                finish_reason=FinishReason("stop"),
+            ChatCompletionChoice(
+                finish_reason="stop",
                 index=0,
-                message=ChatMessage(
-                    role="role",
+                message=AssistantMessage(
+                    role="assistant",
                     content="some response",
-                    name=None,
+                    name=None,  # type: ignore - mistralai issue
                     tool_calls=None,
-                    tool_call_id=None,
+                    tool_call_id=None,  # type: ignore - mistralai issue
                 ),
             )
         ],
@@ -1128,20 +1122,12 @@ def _create_mock_chat_completion(
     client = WrappedMistralAIClient(
         api_key="invalid_key", model_name="model_name"
     )
-    fake_message = [{"role": "role", "content": "content"}]
-    with pytest.raises(MistralException):
+    fake_message = [{"role": "assistant", "content": "content"}]
+    with pytest.raises(MistralSDKError):
         client.connect()
         client(fake_message)
 
-    assert [
-        ChatMessage(
-            role="role",
-            content="content",
-            name=None,
-            tool_calls=None,
-            tool_call_id=None,
-        )
-    ] == client._process_messages(fake_message)
+    assert fake_message == client._process_messages(fake_message)
 
     # The Mistral Client should be able to connect if the API key is set as the environment variable.
     os.environ["MISTRAL_API_KEY"] = "dummy_key"
@@ -1151,18 +1137,18 @@ def _create_mock_chat_completion(
     client.client = MagicMock()
 
     # The metric computation should fail if the request fails.
-    client.client.chat = _create_bad_request
+    client.client.chat.complete = _create_bad_request
     with pytest.raises(ValueError) as e:
         client(fake_message)
 
     # The metric computation should fail when the finish reason is bad length.
-    client.client.chat = _create_mock_chat_completion_with_bad_length
+    client.client.chat.complete = _create_mock_chat_completion_with_bad_length
     with pytest.raises(ValueError) as e:
         client(fake_message)
     assert "reached max token limit" in str(e)
 
     # The metric computation should run successfully when the finish reason is stop.
-    client.client.chat = _create_mock_chat_completion
+    client.client.chat.complete = _create_mock_chat_completion
     assert client(fake_message) == "some response"
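
A note on the test changes above: with mistralai>=1.0 the completion entry point hangs off client.chat.complete, so the tests stub that attribute instead of client.chat. Below is a minimal self-contained sketch of the same stubbing pattern; the import path follows the changed file paths in this commit, and the id/created/usage values in the mocked response are illustrative, not taken from the diff.

from unittest.mock import MagicMock

from mistralai.models import (
    AssistantMessage,
    ChatCompletionChoice,
    ChatCompletionResponse,
    UsageInfo,
)

from valor_api.backend.core.llm_clients import WrappedMistralAIClient


def _mock_complete(*args, **kwargs) -> ChatCompletionResponse:
    # Mirror the mock construction used in the tests above.
    return ChatCompletionResponse(
        id="dummy_id",
        created=0,
        model="model_name",
        object="chat.completion",
        choices=[
            ChatCompletionChoice(
                finish_reason="stop",
                index=0,
                message=AssistantMessage(role="assistant", content="some response"),
            )
        ],
        usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0),
    )


client = WrappedMistralAIClient(api_key="dummy_key", model_name="model_name")
client.connect()
client.client = MagicMock()
# Replace the 1.x entry point, exactly as the updated tests do.
client.client.chat.complete = _mock_complete
assert client([{"role": "user", "content": "content"}]) == "some response"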
10 changes: 5 additions & 5 deletions api/tests/functional-tests/backend/metrics/test_detection.py
@@ -2292,11 +2292,11 @@ def test__convert_annotations_to_common_type(db: Session):
     mask[ymin:ymax, xmin:xmax] = True
 
     pts = [
-        (xmin, ymin),
-        (xmin, ymax),
-        (xmax, ymax),
-        (xmax, ymin),
-        (xmin, ymin),
+        (float(xmin), float(ymin)),
+        (float(xmin), float(ymax)),
+        (float(xmax), float(ymax)),
+        (float(xmax), float(ymin)),
+        (float(xmin), float(ymin)),
     ]
     poly = schemas.Polygon(value=[pts])
     raster = schemas.Raster.from_numpy(mask)
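
The only change in this test file casts the polygon vertices to float before constructing schemas.Polygon, presumably because coordinate validation is stricter about integer inputs. A small hypothetical sketch of the pattern (the Polygon model below is a stand-in, not valor's actual schema):

from pydantic import BaseModel


class Polygon(BaseModel):
    # Stand-in schema: a list of rings, each ring a list of (x, y) float pairs.
    value: list[list[tuple[float, float]]]


xmin, ymin, xmax, ymax = 1, 2, 10, 20  # integer box corners
pts = [
    (float(xmin), float(ymin)),
    (float(xmin), float(ymax)),
    (float(xmax), float(ymax)),
    (float(xmax), float(ymin)),
    (float(xmin), float(ymin)),  # repeat the first point to close the ring
]
poly = Polygon(value=[pts])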
81 changes: 23 additions & 58 deletions api/valor_api/backend/core/llm_clients.py
@@ -1,14 +1,7 @@
 from typing import Any
 
-from mistralai.client import MistralClient
-from mistralai.models.chat_completion import ChatMessage
-from openai import OpenAI as OpenAIClient
-from openai.types.chat import (
-    ChatCompletionAssistantMessageParam,
-    ChatCompletionMessageParam,
-    ChatCompletionSystemMessageParam,
-    ChatCompletionUserMessageParam,
-)
+from mistralai.sdk import Mistral
+from openai import OpenAI
 from pydantic import BaseModel

from valor_api.backend.core.llm_instructions_analysis import (
@@ -830,14 +823,14 @@ def connect(
         Setup the connection to the API.
         """
         if self.api_key is None:
-            self.client = OpenAIClient()
+            self.client = OpenAI()
         else:
-            self.client = OpenAIClient(api_key=self.api_key)
+            self.client = OpenAI(api_key=self.api_key)
 
     def _process_messages(
         self,
         messages: list[dict[str, str]],
-    ) -> list[ChatCompletionMessageParam]:
+    ) -> list[dict[str, str]]:
         """
         Format messages for the API.
@@ -848,40 +841,13 @@ def _process_messages(
         Returns
         -------
-        list[ChatCompletionMessageParam]
-            The messages converted to the OpenAI client message objects.
+        list[dict[str, str]]
+            The messages are left in the OpenAI standard.
         """
         # Validate that the input is a list of dictionaries with "role" and "content" keys.
         _ = Messages(messages=messages)  # type: ignore
 
-        ret = []
-        for i in range(len(messages)):
-            if messages[i]["role"] == "system":
-                ret.append(
-                    ChatCompletionSystemMessageParam(
-                        content=messages[i]["content"],
-                        role="system",
-                    )
-                )
-            elif messages[i]["role"] == "user":
-                ret.append(
-                    ChatCompletionUserMessageParam(
-                        content=messages[i]["content"],
-                        role="user",
-                    )
-                )
-            elif messages[i]["role"] == "assistant":
-                ret.append(
-                    ChatCompletionAssistantMessageParam(
-                        content=messages[i]["content"],
-                        role="assistant",
-                    )
-                )
-            else:
-                raise ValueError(
-                    f"Role {messages[i]['role']} is not supported by OpenAI."
-                )
-        return ret
+        return messages
 
     def __call__(
         self,
@@ -903,7 +869,7 @@ def __call__(
         processed_messages = self._process_messages(messages)
         openai_response = self.client.chat.completions.create(
             model=self.model_name,
-            messages=processed_messages,
+            messages=processed_messages,  # type: ignore - mistralai issue
             seed=self.seed,
         )
 
@@ -965,9 +931,9 @@ def connect(
         Setup the connection to the API.
         """
         if self.api_key is None:
-            self.client = MistralClient()
+            self.client = Mistral()
         else:
-            self.client = MistralClient(api_key=self.api_key)
+            self.client = Mistral(api_key=self.api_key)
 
     def _process_messages(
         self,
@@ -984,20 +950,12 @@ def _process_messages(
         Returns
         -------
         Any
-            The messages formatted for Mistral's API. Each message is converted to a mistralai.models.chat_completion.ChatMessage object.
+            The messages formatted for Mistral's API. With mistralai>=1.0.0, the messages can be left in the OpenAI standard.
         """
         # Validate that the input is a list of dictionaries with "role" and "content" keys.
         _ = Messages(messages=messages)  # type: ignore
 
-        ret = []
-        for i in range(len(messages)):
-            ret.append(
-                ChatMessage(
-                    role=messages[i]["role"],
-                    content=messages[i]["content"],
-                )
-            )
-        return ret
+        return messages
 
     def __call__(
         self,
@@ -1017,22 +975,29 @@ def __call__(
         The response from the API.
         """
         processed_messages = self._process_messages(messages)
-        mistral_response = self.client.chat(
+        mistral_response = self.client.chat.complete(
             model=self.model_name,
             messages=processed_messages,
         )
+        if mistral_response is None or mistral_response.choices is None:
+            return ""
 
-        finish_reason = mistral_response.choices[
-            0
-        ].finish_reason  # Enum: "stop" "length" "model_length" "error" "tool_calls"
-        response = mistral_response.choices[0].message.content
+        finish_reason = mistral_response.choices[0].finish_reason
+        if mistral_response.choices[0].message is None:
+            response = ""
+        else:
+            response = mistral_response.choices[0].message.content
 
         if finish_reason == "length":
             raise ValueError(
                 "Mistral response reached max token limit. Resulting evaluation is likely invalid or of low quality."
             )
 
+        if not isinstance(response, str):
+            raise TypeError("Mistral AI response was not a string.")
+
         return response


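Taken together, the llm_clients.py changes reduce to a small call-pattern migration. Here is a hedged sketch of the 1.x usage the wrapper now relies on; the model name and prompt are placeholders, and the unpacking mirrors the checks visible in the diff rather than the full wrapper:

import os

from mistralai.sdk import Mistral

# mistralai>=1.0 exposes a single Mistral client with completions under
# chat.complete; pre-1.0 code built MistralClient and called client.chat(...)
# with ChatMessage objects.
client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])

response = client.chat.complete(
    model="open-mistral-7b",  # placeholder model name, not taken from the diff
    messages=[{"role": "user", "content": "Say hello."}],  # plain OpenAI-style dicts
)

# Defensive unpacking, mirroring the wrapper's checks.
if response is None or response.choices is None:
    content = ""
else:
    choice = response.choices[0]
    if choice.finish_reason == "length":
        raise ValueError("Mistral response reached max token limit.")
    content = choice.message.content if choice.message else ""

print(content)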
(Diffs for the remaining 4 changed files are not shown.)
