Mistral Client #2892

Merged (28 commits) on Jun 21, 2024
Changes from all commits

Commits (28)
c1c24ff
Initial commit of Mistral client class
marklysze May 23, 2024
5e25c7f
Updated to manage final system message for reflection_with_llm
marklysze Jun 7, 2024
d3f4783
Add Mistral support to client class
marklysze Jun 7, 2024
322c402
Add Mistral support across the board (based on Gemini changes)
marklysze Jun 7, 2024
d88b6e8
Test file for Mistral client
marklysze Jun 8, 2024
39b7c1e
Merge remote-tracking branch 'origin/main' into mistral_client
marklysze Jun 8, 2024
ddfe25a
Updated handling of config, added notebook for documentation
marklysze Jun 9, 2024
fcab23b
Added support for additional API parameters
marklysze Jun 11, 2024
691c38c
Remove unneeded code, updated exception raising
marklysze Jun 11, 2024
4d6b0c8
Updated handling of keywords, including type checks, defaults, warnin…
marklysze Jun 11, 2024
ad3033a
Added class description.
marklysze Jun 12, 2024
a8bb96f
Updated tests to support new config handling.
marklysze Jun 12, 2024
3acb446
Moved parameter parsing to create function, minimised init, added par…
marklysze Jun 13, 2024
b988eec
Refined parameter validation
marklysze Jun 14, 2024
172ce3f
Correct spacing
marklysze Jun 14, 2024
8840963
Fixed string concat in parameter validation
marklysze Jun 14, 2024
fef5a1a
Corrected upper/lower bound warning
marklysze Jun 14, 2024
07f1d5a
Merge remote-tracking branch 'origin/main' into mistral_client
marklysze Jun 15, 2024
3d0fc76
Merge remote-tracking branch 'origin/main' into mistral_client
marklysze Jun 19, 2024
dd77683
Use client_tools, tidy up Mistral create, better handle tool call res…
marklysze Jun 19, 2024
3c8f27b
Update of documentation notebook, replacement of old version
marklysze Jun 19, 2024
8454762
Update to handle multiple tool_call recommendations in a message
marklysze Jun 19, 2024
d4a9186
Merge remote-tracking branch 'origin/main' into mistral_client
marklysze Jun 20, 2024
cc8ebe4
Updated tests to accommodate multiple tool_calls as well as content i…
marklysze Jun 20, 2024
0a2a690
Update autogen/oai/mistral.py comment
marklysze Jun 20, 2024
7ce87b4
cleanup, rewrite mock
yiranwu0 Jun 21, 2024
1a88d22
Merge remote-tracking branch 'origin/main' into mistral_client
yiranwu0 Jun 21, 2024
3aab63d
update
yiranwu0 Jun 21, 2024
40 changes: 40 additions & 0 deletions .github/workflows/contrib-tests.yml
@@ -518,3 +518,43 @@ jobs:
        with:
          file: ./coverage.xml
          flags: unittests

  MistralTest:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, macos-latest, windows-2019]
        python-version: ["3.9", "3.10", "3.11", "3.12"]
        exclude:
          - os: macos-latest
            python-version: "3.9"
    steps:
      - uses: actions/checkout@v4
        with:
          lfs: true
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install packages and dependencies for all tests
        run: |
          python -m pip install --upgrade pip wheel
          pip install "pytest-cov>=5"
      - name: Install packages and dependencies for Mistral
        run: |
          pip install -e .[mistral,test]
      - name: Set AUTOGEN_USE_DOCKER based on OS
        shell: bash
        run: |
          if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
            echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
          fi
      - name: Coverage
        run: |
          pytest test/oai/test_mistral.py --skip-openai
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
          file: ./coverage.xml
          flags: unittests
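For a local sanity check of the new job, a mock-based test along these lines mirrors what the workflow runs with pytest test/oai/test_mistral.py --skip-openai. This is a hypothetical sketch, not the contents of the actual test file: the patched path follows the MistralClient import in autogen/oai/mistral.py below, and the mocked response carries only the fields that MistralAIClient.create() reads.

from unittest.mock import MagicMock, patch

from autogen.oai.mistral import MistralAIClient


@patch("autogen.oai.mistral.MistralClient")  # patch the Mistral SDK client that create() instantiates
def test_create_returns_oai_compatible_response(mock_client_cls):
    # Fake Mistral ChatCompletionResponse carrying only the fields create() reads
    mock_response = MagicMock()
    mock_response.id = "chatcmpl-test"
    mock_response.model = "open-mistral-7b"
    mock_response.choices = [MagicMock(finish_reason="stop", message=MagicMock(content="2", tool_calls=None))]
    mock_response.usage.prompt_tokens = 10
    mock_response.usage.completion_tokens = 2
    mock_client_cls.return_value.chat.return_value = mock_response

    client = MistralAIClient(api_key="dummy")
    response = client.create({"model": "open-mistral-7b", "messages": [{"role": "user", "content": "1+1?"}]})

    assert response.choices[0].message.content == "2"
    assert response.usage.total_tokens == 12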
3 changes: 2 additions & 1 deletion autogen/logger/file_logger.py
@@ -19,6 +19,7 @@
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
from autogen.oai.gemini import GeminiClient
from autogen.oai.mistral import MistralAIClient

logger = logging.getLogger(__name__)

@@ -202,7 +203,7 @@ def log_new_wrapper(

    def log_new_client(
        self,
-       client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient,
+       client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient,
        wrapper: OpenAIWrapper,
        init_args: Dict[str, Any],
    ) -> None:
3 changes: 2 additions & 1 deletion autogen/logger/sqlite_logger.py
@@ -20,6 +20,7 @@
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
from autogen.oai.gemini import GeminiClient
from autogen.oai.mistral import MistralAIClient

logger = logging.getLogger(__name__)
lock = threading.Lock()
@@ -389,7 +390,7 @@ def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[st

    def log_new_client(
        self,
-       client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient],
+       client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient],
        wrapper: OpenAIWrapper,
        init_args: Dict[str, Any],
    ) -> None:
12 changes: 12 additions & 0 deletions autogen/oai/client.py
@@ -56,6 +56,13 @@
except ImportError as e:
    anthropic_import_exception = e

try:
    from autogen.oai.mistral import MistralAIClient

    mistral_import_exception: Optional[ImportError] = None
except ImportError as e:
    mistral_import_exception = e

logger = logging.getLogger(__name__)
if not logger.handlers:
    # Add the console handler.
@@ -461,6 +468,11 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
                raise ImportError("Please install `anthropic` to use Anthropic API.")
            client = AnthropicClient(**openai_config)
            self._clients.append(client)
        elif api_type is not None and api_type.startswith("mistral"):
            if mistral_import_exception:
                raise ImportError("Please install `mistralai` to use the Mistral.AI API.")
            client = MistralAIClient(**openai_config)
            self._clients.append(client)
        else:
            client = OpenAI(**openai_config)
            self._clients.append(OpenAIClient(client))
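With this branch in place, OpenAIWrapper dispatches any config entry whose api_type starts with "mistral" to MistralAIClient. A minimal usage sketch, assuming pip install pyautogen[mistral] and a valid MISTRAL_API_KEY:

import os

import autogen

config_list = [
    {
        "api_type": "mistral",
        "model": "open-mistral-7b",
        "api_key": os.environ.get("MISTRAL_API_KEY"),
    }
]

# The wrapper instantiates MistralAIClient internally via the branch above
wrapper = autogen.OpenAIWrapper(config_list=config_list)
response = wrapper.create(messages=[{"role": "user", "content": "Hello"}])
print(wrapper.extract_text_or_completion_object(response))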
5 changes: 3 additions & 2 deletions autogen/oai/client_utils.py
@@ -135,8 +135,9 @@ def should_hide_tools(messages: List[Dict[str, Any]], tools: List[Dict[str, Any]
        # Loop through the messages and check if the tools have been run, removing them as we go
        for message in messages:
            if "tool_calls" in message:
-               # Register the tool id and the name
-               tool_call_ids[message["tool_calls"][0]["id"]] = message["tool_calls"][0]["function"]["name"]
+               # Register the tool ids and the function names (there could be multiple tool calls)
+               for tool_call in message["tool_calls"]:
+                   tool_call_ids[tool_call["id"]] = tool_call["function"]["name"]
            elif "tool_call_id" in message:
                # Tool called, get the name of the function based on the id
                tool_name_called = tool_call_ids[message["tool_call_id"]]
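This change matters when one assistant message recommends several tools: previously only the first call's id was registered, so the result message for any later call could not be resolved. A sketch of the corrected behaviour, with illustrative tool names:

from autogen.oai.client_utils import should_hide_tools

tools = [
    {"type": "function", "function": {"name": "add"}},
    {"type": "function", "function": {"name": "sub"}},
]
messages = [
    {"role": "user", "content": "Compute 5+3 and 5-3."},
    {
        "role": "assistant",
        "tool_calls": [
            {"id": "call_1", "function": {"name": "add", "arguments": "{}"}},
            {"id": "call_2", "function": {"name": "sub", "arguments": "{}"}},
        ],
    },
    {"role": "tool", "tool_call_id": "call_1", "content": "8"},
    {"role": "tool", "tool_call_id": "call_2", "content": "2"},
]

# Both calls from the single assistant message are now registered, so once both
# results are in, "if_all_run" reports that every listed tool has been run
assert should_hide_tools(messages, tools, "if_all_run")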
227 changes: 227 additions & 0 deletions autogen/oai/mistral.py
@@ -0,0 +1,227 @@
"""Create an OpenAI-compatible client using Mistral.AI's API.

Example:
llm_config={
"config_list": [{
"api_type": "mistral",
"model": "open-mixtral-8x22b",
"api_key": os.environ.get("MISTRAL_API_KEY")
}
]}

agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)

Install Mistral.AI python library using: pip install --upgrade mistralai

Resources:
- https://docs.mistral.ai/getting-started/quickstart/
"""

# Important notes when using the Mistral.AI API:
# The first system message can greatly affect whether the model returns a tool call; including text that
# references the ability to use functions will help.
# Changing the role of the first system message to 'user' improved the chances of the model recommending a tool call.

import inspect
import json
import os
import time
import warnings
from typing import Any, Dict, List, Tuple, Union

# Mistral libraries
# pip install mistralai
from mistralai.client import MistralClient
from mistralai.exceptions import MistralAPIException
from mistralai.models.chat_completion import ChatCompletionResponse, ChatMessage, ToolCall
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from typing_extensions import Annotated

from autogen.oai.client_utils import should_hide_tools, validate_parameter


class MistralAIClient:
    """Client for Mistral.AI's API."""

    def __init__(self, **kwargs):
        """Requires api_key or environment variable to be set

        Args:
            api_key (str): The API key for using Mistral.AI (or environment variable MISTRAL_API_KEY needs to be set)
        """
        # Ensure we have the api_key upon instantiation
        self.api_key = kwargs.get("api_key", None)
        if not self.api_key:
            self.api_key = os.getenv("MISTRAL_API_KEY", None)

        assert (
            self.api_key
        ), "Please specify the 'api_key' in your config list entry for Mistral or set the MISTRAL_API_KEY env variable."

    def message_retrieval(self, response: ChatCompletionResponse) -> Union[List[str], List[ChatCompletionMessage]]:
        """Retrieve the messages from the response."""

        return [choice.message for choice in response.choices]

    def cost(self, response) -> float:
        return response.cost

    def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Loads the parameters for the Mistral.AI API from the passed-in parameters and returns a validated set. Checks types, ranges, and sets defaults."""
        mistral_params = {}

        # 1. Validate models
        mistral_params["model"] = params.get("model", None)
        assert mistral_params[
            "model"
        ], "Please specify the 'model' in your config list entry to nominate the Mistral.AI model to use."

        # 2. Validate allowed Mistral.AI parameters
        mistral_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 0.7, None, None)
        mistral_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
        mistral_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
        mistral_params["safe_prompt"] = validate_parameter(
            params, "safe_prompt", bool, False, False, None, [True, False]
        )
        mistral_params["random_seed"] = validate_parameter(params, "random_seed", int, True, None, None, None)

        # 3. Convert messages to Mistral format
        mistral_messages = []
        tool_call_ids = {}  # tool call ids to function name mapping
        for message in params["messages"]:
            if message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"] is not None:
                # Convert OAI ToolCall to Mistral ToolCall
                openai_toolcalls = message["tool_calls"]
                mistral_toolcalls = []
                for toolcall in openai_toolcalls:
                    mistral_toolcall = ToolCall(id=toolcall["id"], function=toolcall["function"])
                    mistral_toolcalls.append(mistral_toolcall)
                mistral_messages.append(
                    ChatMessage(role=message["role"], content=message["content"], tool_calls=mistral_toolcalls)
                )

                # Map tool call id to the function name
                for tool_call in message["tool_calls"]:
                    tool_call_ids[tool_call["id"]] = tool_call["function"]["name"]

            elif message["role"] in ("system", "user", "assistant"):
                # Note: this ChatMessage can take a 'name' field, but the Mistral API rejects it unless role='tool', so the 'name' field is not used here.
                mistral_messages.append(ChatMessage(role=message["role"], content=message["content"]))

            elif message["role"] == "tool":
                # Indicates the result of a tool call; the 'name' is the function that was called
                mistral_messages.append(
                    ChatMessage(
                        role="tool",
                        name=tool_call_ids[message["tool_call_id"]],
                        content=message["content"],
                        tool_call_id=message["tool_call_id"],
                    )
                )
            else:
                warnings.warn(f"Unknown message role {message['role']}", UserWarning)

        # If a 'system' message follows an 'assistant' message, change it to 'user'
        # This can occur when using LLM summarisation
        for i in range(1, len(mistral_messages)):
            if mistral_messages[i - 1].role == "assistant" and mistral_messages[i].role == "system":
                mistral_messages[i].role = "user"

        mistral_params["messages"] = mistral_messages

        # 4. Add tools to the call if we have them and aren't hiding them
        if "tools" in params:
            hide_tools = validate_parameter(
                params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
            )
            if not should_hide_tools(params["messages"], params["tools"], hide_tools):
                mistral_params["tools"] = params["tools"]
        return mistral_params

    def create(self, params: Dict[str, Any]) -> ChatCompletion:
        # 1. Parse parameters to Mistral.AI API's parameters
        mistral_params = self.parse_params(params)

        # 2. Call Mistral.AI API
        client = MistralClient(api_key=self.api_key)
        mistral_response = client.chat(**mistral_params)
        # TODO: Handle streaming

        # 3. Convert Mistral response to OAI compatible format
        if mistral_response.choices[0].finish_reason == "tool_calls":
            mistral_finish = "tool_calls"
            tool_calls = []
            for tool_call in mistral_response.choices[0].message.tool_calls:
                tool_calls.append(
                    ChatCompletionMessageToolCall(
                        id=tool_call.id,
                        function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
                        type="function",
                    )
                )
        else:
            mistral_finish = "stop"
            tool_calls = None

        message = ChatCompletionMessage(
            role="assistant",
            content=mistral_response.choices[0].message.content,
            function_call=None,
            tool_calls=tool_calls,
        )
        choices = [Choice(finish_reason=mistral_finish, index=0, message=message)]

        response_oai = ChatCompletion(
            id=mistral_response.id,
            model=mistral_response.model,
            created=int(time.time()),  # OpenAI-compatible 'created' is a Unix timestamp in seconds
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=mistral_response.usage.prompt_tokens,
                completion_tokens=mistral_response.usage.completion_tokens,
                total_tokens=mistral_response.usage.prompt_tokens + mistral_response.usage.completion_tokens,
            ),
            cost=calculate_mistral_cost(
                mistral_response.usage.prompt_tokens, mistral_response.usage.completion_tokens, mistral_response.model
            ),
        )

        return response_oai

    @staticmethod
    def get_usage(response: ChatCompletionResponse) -> Dict:
        return {
            "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0,
            "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0,
            "total_tokens": (
                response.usage.prompt_tokens + response.usage.completion_tokens if response.usage is not None else 0
            ),
            "cost": response.cost if hasattr(response, "cost") else 0,
            "model": response.model,
        }


def calculate_mistral_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
    """Calculate the cost of the Mistral response."""

    # Prices per 1 million tokens
    # https://mistral.ai/technology/
    model_cost_map = {
        "open-mistral-7b": {"input": 0.25, "output": 0.25},
        "open-mixtral-8x7b": {"input": 0.7, "output": 0.7},
        "open-mixtral-8x22b": {"input": 2.0, "output": 6.0},
        "mistral-small-latest": {"input": 1.0, "output": 3.0},
        "mistral-medium-latest": {"input": 2.7, "output": 8.1},
        "mistral-large-latest": {"input": 4.0, "output": 12.0},
    }

    # Ensure we have the model they are using and return the total cost
    if model_name in model_cost_map:
        costs = model_cost_map[model_name]

        return (input_tokens * costs["input"] / 1_000_000) + (output_tokens * costs["output"] / 1_000_000)
    else:
        warnings.warn(f"Cost calculation is not implemented for model {model_name}, will return $0.", UserWarning)
        return 0
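For reference, direct (wrapper-free) use of the client looks roughly like the following. This is a sketch assuming MISTRAL_API_KEY is set and the package is installed with the mistral extra; the prompt and model choice are illustrative.

import os

from autogen.oai.mistral import MistralAIClient, calculate_mistral_cost

client = MistralAIClient(api_key=os.environ["MISTRAL_API_KEY"])
response = client.create(
    {
        "model": "open-mistral-7b",
        "messages": [{"role": "user", "content": "What is the capital of France?"}],
        "temperature": 0.5,
    }
)

print(client.message_retrieval(response)[0].content)
print(MistralAIClient.get_usage(response))  # token counts plus the cost computed above

# Pricing sanity check against the table above: 1,000 prompt + 500 completion
# tokens on open-mixtral-8x22b is 1000 * 2.0 / 1e6 + 500 * 6.0 / 1e6 = 0.005 USD
assert abs(calculate_mistral_cost(1000, 500, "open-mixtral-8x22b") - 0.005) < 1e-9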
5 changes: 4 additions & 1 deletion autogen/runtime_logging.py
@@ -15,6 +15,7 @@
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
from autogen.oai.gemini import GeminiClient
from autogen.oai.mistral import MistralAIClient

logger = logging.getLogger(__name__)

@@ -108,7 +109,9 @@ def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig


def log_new_client(
-   client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient], wrapper: OpenAIWrapper, init_args: Dict[str, Any]
+   client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient],
+   wrapper: OpenAIWrapper,
+   init_args: Dict[str, Any],
) -> None:
    if autogen_logger is None:
        logger.error("[runtime logging] log_new_client: autogen logger is None")
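Since log_new_client now accepts MistralAIClient, Mistral usage is captured by runtime logging like any other provider. A brief sketch with the default SQLite logger; the database name is illustrative:

import autogen

# Any MistralAIClient created while logging is active will be recorded
session_id = autogen.runtime_logging.start(config={"dbname": "logs.db"})

# ... build agents or an OpenAIWrapper with a Mistral config list and chat here ...

autogen.runtime_logging.stop()
print(f"Runtime logged under session {session_id}")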
3 changes: 3 additions & 0 deletions autogen/token_count_utils.py
@@ -122,6 +122,9 @@ def _num_token_from_messages(messages: Union[List, Dict], model="gpt-3.5-turbo-0
elif "claude" in model:
logger.info("Claude is not supported in tiktoken. Returning num tokens assuming gpt-4-0613.")
return _num_token_from_messages(messages, model="gpt-4-0613")
elif "mistral-" in model or "mixtral-" in model:
logger.info("Mistral.AI models are not supported in tiktoken. Returning num tokens assuming gpt-4-0613.")
return _num_token_from_messages(messages, model="gpt-4-0613")
else:
raise NotImplementedError(
f"""_num_token_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
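With this fallback, count_token accepts Mistral and Mixtral model names by approximating with the gpt-4-0613 tokenizer; since Mistral uses its own tokenizer, treat the result as an estimate. For example:

from autogen.token_count_utils import count_token

messages = [{"role": "user", "content": "Hello, Mistral!"}]
# Logs an info message and falls back to gpt-4-0613 tokenization
print(count_token(messages, model="mistral-large-latest"))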
1 change: 1 addition & 0 deletions setup.py
@@ -89,6 +89,7 @@
"types": ["mypy==1.9.0", "pytest>=6.1.1,<8"] + jupyter_executor,
"long-context": ["llmlingua<0.3"],
"anthropic": ["anthropic>=0.23.1"],
"mistral": ["mistralai>=0.2.0"],
}

setuptools.setup(