diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml
index 05e7349d111..0535aa25f3b 100644
--- a/.github/workflows/contrib-tests.yml
+++ b/.github/workflows/contrib-tests.yml
@@ -518,3 +518,43 @@ jobs:
with:
file: ./coverage.xml
flags: unittests
+
+ MistralTest:
+ runs-on: ${{ matrix.os }}
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-2019]
+ python-version: ["3.9", "3.10", "3.11", "3.12"]
+ exclude:
+ - os: macos-latest
+ python-version: "3.9"
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ lfs: true
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install packages and dependencies for all tests
+ run: |
+ python -m pip install --upgrade pip wheel
+        pip install "pytest-cov>=5"
+ - name: Install packages and dependencies for Mistral
+ run: |
+ pip install -e .[mistral,test]
+ - name: Set AUTOGEN_USE_DOCKER based on OS
+ shell: bash
+ run: |
+ if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
+ echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
+ fi
+ - name: Coverage
+ run: |
+ pytest test/oai/test_mistral.py --skip-openai
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v3
+ with:
+ file: ./coverage.xml
+ flags: unittests
diff --git a/autogen/logger/file_logger.py b/autogen/logger/file_logger.py
index 4e151347837..3b97bf04aff 100644
--- a/autogen/logger/file_logger.py
+++ b/autogen/logger/file_logger.py
@@ -19,6 +19,7 @@
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
from autogen.oai.gemini import GeminiClient
+ from autogen.oai.mistral import MistralAIClient
logger = logging.getLogger(__name__)
@@ -202,7 +203,7 @@ def log_new_wrapper(
def log_new_client(
self,
- client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient,
+ client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient,
wrapper: OpenAIWrapper,
init_args: Dict[str, Any],
) -> None:
diff --git a/autogen/logger/sqlite_logger.py b/autogen/logger/sqlite_logger.py
index 1e80bc8751e..6f80c86a3dc 100644
--- a/autogen/logger/sqlite_logger.py
+++ b/autogen/logger/sqlite_logger.py
@@ -20,6 +20,7 @@
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
from autogen.oai.gemini import GeminiClient
+ from autogen.oai.mistral import MistralAIClient
logger = logging.getLogger(__name__)
lock = threading.Lock()
@@ -389,7 +390,7 @@ def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[st
def log_new_client(
self,
- client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient],
+ client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient],
wrapper: OpenAIWrapper,
init_args: Dict[str, Any],
) -> None:
diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index e86ecc51282..87c22954174 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -56,6 +56,13 @@
except ImportError as e:
anthropic_import_exception = e
+try:
+ from autogen.oai.mistral import MistralAIClient
+
+ mistral_import_exception: Optional[ImportError] = None
+except ImportError as e:
+ mistral_import_exception = e
+
logger = logging.getLogger(__name__)
if not logger.handlers:
# Add the console handler.
@@ -461,6 +468,11 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
raise ImportError("Please install `anthropic` to use Anthropic API.")
client = AnthropicClient(**openai_config)
self._clients.append(client)
+ elif api_type is not None and api_type.startswith("mistral"):
+ if mistral_import_exception:
+ raise ImportError("Please install `mistralai` to use the Mistral.AI API.")
+ client = MistralAIClient(**openai_config)
+ self._clients.append(client)
else:
client = OpenAI(**openai_config)
self._clients.append(OpenAIClient(client))
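With this branch in place, any config entry whose `api_type` starts with `mistral` is routed to `MistralAIClient`. A minimal sketch of exercising the new branch through `OpenAIWrapper` (the model choice and prompt are illustrative):

```python
import os

from autogen import OpenAIWrapper

# api_type "mistral" selects MistralAIClient in _register_default_client
wrapper = OpenAIWrapper(
    config_list=[
        {
            "api_type": "mistral",
            "model": "open-mistral-7b",  # illustrative model choice
            "api_key": os.environ.get("MISTRAL_API_KEY"),
        }
    ]
)

response = wrapper.create(messages=[{"role": "user", "content": "Hello!"}])
print(wrapper.extract_text_or_completion_object(response))
```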
diff --git a/autogen/oai/client_utils.py b/autogen/oai/client_utils.py
index 143168d5d9f..55730485b40 100644
--- a/autogen/oai/client_utils.py
+++ b/autogen/oai/client_utils.py
@@ -135,8 +135,9 @@ def should_hide_tools(messages: List[Dict[str, Any]], tools: List[Dict[str, Any]
# Loop through the messages and check if the tools have been run, removing them as we go
for message in messages:
if "tool_calls" in message:
- # Register the tool id and the name
- tool_call_ids[message["tool_calls"][0]["id"]] = message["tool_calls"][0]["function"]["name"]
+ # Register the tool ids and the function names (there could be multiple tool calls)
+ for tool_call in message["tool_calls"]:
+ tool_call_ids[tool_call["id"]] = tool_call["function"]["name"]
elif "tool_call_id" in message:
# Tool called, get the name of the function based on the id
tool_name_called = tool_call_ids[message["tool_call_id"]]
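Previously only the first entry of `tool_calls` was registered, so `should_hide_tools` could misjudge which tools had run when one assistant message carried several calls. A small sketch of the corrected behaviour (tool names and ids are illustrative):

```python
from autogen.oai.client_utils import should_hide_tools

tools = [
    {"type": "function", "function": {"name": "get_weather"}},
    {"type": "function", "function": {"name": "get_time"}},
]

messages = [
    {"role": "user", "content": "Weather and time in Chicago?"},
    {
        "role": "assistant",
        "tool_calls": [  # one assistant turn, two tool calls
            {"id": "call_1", "function": {"name": "get_weather", "arguments": "{}"}},
            {"id": "call_2", "function": {"name": "get_time", "arguments": "{}"}},
        ],
    },
    {"role": "tool", "tool_call_id": "call_1", "content": "Sunny"},  # only one result so far
]

# Only get_weather has returned a result, so:
assert should_hide_tools(messages, tools, "if_any_run")  # at least one tool has run
assert not should_hide_tools(messages, tools, "if_all_run")  # get_time has not run yet
```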
diff --git a/autogen/oai/mistral.py b/autogen/oai/mistral.py
new file mode 100644
index 00000000000..832369376af
--- /dev/null
+++ b/autogen/oai/mistral.py
@@ -0,0 +1,227 @@
+"""Create an OpenAI-compatible client using Mistral.AI's API.
+
+Example:
+    llm_config = {
+        "config_list": [
+            {
+                "api_type": "mistral",
+                "model": "open-mixtral-8x22b",
+                "api_key": os.environ.get("MISTRAL_API_KEY"),
+            }
+        ]
+    }
+
+ agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
+
+Install the Mistral.AI python library using: pip install --upgrade mistralai
+
+Resources:
+- https://docs.mistral.ai/getting-started/quickstart/
+"""
+
+# Important notes when using the Mistral.AI API:
+# The first system message can greatly affect whether the model returns a tool call; including text that references the ability to use functions helps.
+# Changing the role of the first system message to 'user' improved the chances of the model recommending a tool call.
+
+import os
+import time
+import warnings
+from typing import Any, Dict, List, Union
+
+# Mistral libraries
+# pip install mistralai
+from mistralai.client import MistralClient
+from mistralai.models.chat_completion import ChatCompletionResponse, ChatMessage, ToolCall
+from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
+from openai.types.completion_usage import CompletionUsage
+
+from autogen.oai.client_utils import should_hide_tools, validate_parameter
+
+
+class MistralAIClient:
+ """Client for Mistral.AI's API."""
+
+ def __init__(self, **kwargs):
+ """Requires api_key or environment variable to be set
+
+ Args:
+ api_key (str): The API key for using Mistral.AI (or environment variable MISTRAL_API_KEY needs to be set)
+ """
+ # Ensure we have the api_key upon instantiation
+ self.api_key = kwargs.get("api_key", None)
+ if not self.api_key:
+ self.api_key = os.getenv("MISTRAL_API_KEY", None)
+
+ assert (
+ self.api_key
+ ), "Please specify the 'api_key' in your config list entry for Mistral or set the MISTRAL_API_KEY env variable."
+
+ def message_retrieval(self, response: ChatCompletionResponse) -> Union[List[str], List[ChatCompletionMessage]]:
+ """Retrieve the messages from the response."""
+
+ return [choice.message for choice in response.choices]
+
+ def cost(self, response) -> float:
+ return response.cost
+
+ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
+ """Loads the parameters for Mistral.AI API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults"""
+ mistral_params = {}
+
+ # 1. Validate models
+ mistral_params["model"] = params.get("model", None)
+ assert mistral_params[
+ "model"
+ ], "Please specify the 'model' in your config list entry to nominate the Mistral.ai model to use."
+
+ # 2. Validate allowed Mistral.AI parameters
+ mistral_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 0.7, None, None)
+ mistral_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
+ mistral_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
+ mistral_params["safe_prompt"] = validate_parameter(
+ params, "safe_prompt", bool, False, False, None, [True, False]
+ )
+ mistral_params["random_seed"] = validate_parameter(params, "random_seed", int, True, None, False, None)
+
+ # 3. Convert messages to Mistral format
+ mistral_messages = []
+ tool_call_ids = {} # tool call ids to function name mapping
+ for message in params["messages"]:
+ if message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"] is not None:
+ # Convert OAI ToolCall to Mistral ToolCall
+ openai_toolcalls = message["tool_calls"]
+ mistral_toolcalls = []
+ for toolcall in openai_toolcalls:
+ mistral_toolcall = ToolCall(id=toolcall["id"], function=toolcall["function"])
+ mistral_toolcalls.append(mistral_toolcall)
+ mistral_messages.append(
+ ChatMessage(role=message["role"], content=message["content"], tool_calls=mistral_toolcalls)
+ )
+
+ # Map tool call id to the function name
+ for tool_call in message["tool_calls"]:
+ tool_call_ids[tool_call["id"]] = tool_call["function"]["name"]
+
+ elif message["role"] in ("system", "user", "assistant"):
+                # Note: ChatMessage accepts a 'name' field, but the Mistral API rejects it unless role='tool', so it is not used here.
+ mistral_messages.append(ChatMessage(role=message["role"], content=message["content"]))
+
+ elif message["role"] == "tool":
+ # Indicates the result of a tool call, the name is the function name called
+ mistral_messages.append(
+ ChatMessage(
+ role="tool",
+ name=tool_call_ids[message["tool_call_id"]],
+ content=message["content"],
+ tool_call_id=message["tool_call_id"],
+ )
+ )
+ else:
+ warnings.warn(f"Unknown message role {message['role']}", UserWarning)
+
+ # If a 'system' message follows an 'assistant' message, change it to 'user'
+ # This can occur when using LLM summarisation
+ for i in range(1, len(mistral_messages)):
+ if mistral_messages[i - 1].role == "assistant" and mistral_messages[i].role == "system":
+ mistral_messages[i].role = "user"
+
+ mistral_params["messages"] = mistral_messages
+
+ # 4. Add tools to the call if we have them and aren't hiding them
+ if "tools" in params:
+ hide_tools = validate_parameter(
+ params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
+ )
+ if not should_hide_tools(params["messages"], params["tools"], hide_tools):
+ mistral_params["tools"] = params["tools"]
+ return mistral_params
+
+ def create(self, params: Dict[str, Any]) -> ChatCompletion:
+ # 1. Parse parameters to Mistral.AI API's parameters
+ mistral_params = self.parse_params(params)
+
+ # 2. Call Mistral.AI API
+ client = MistralClient(api_key=self.api_key)
+ mistral_response = client.chat(**mistral_params)
+ # TODO: Handle streaming
+
+ # 3. Convert Mistral response to OAI compatible format
+ if mistral_response.choices[0].finish_reason == "tool_calls":
+ mistral_finish = "tool_calls"
+ tool_calls = []
+ for tool_call in mistral_response.choices[0].message.tool_calls:
+ tool_calls.append(
+ ChatCompletionMessageToolCall(
+ id=tool_call.id,
+ function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
+ type="function",
+ )
+ )
+ else:
+ mistral_finish = "stop"
+ tool_calls = None
+
+ message = ChatCompletionMessage(
+ role="assistant",
+ content=mistral_response.choices[0].message.content,
+ function_call=None,
+ tool_calls=tool_calls,
+ )
+ choices = [Choice(finish_reason=mistral_finish, index=0, message=message)]
+
+ response_oai = ChatCompletion(
+ id=mistral_response.id,
+ model=mistral_response.model,
+            created=int(time.time()),  # OpenAI's 'created' field is a Unix timestamp in seconds
+ object="chat.completion",
+ choices=choices,
+ usage=CompletionUsage(
+ prompt_tokens=mistral_response.usage.prompt_tokens,
+ completion_tokens=mistral_response.usage.completion_tokens,
+ total_tokens=mistral_response.usage.prompt_tokens + mistral_response.usage.completion_tokens,
+ ),
+ cost=calculate_mistral_cost(
+ mistral_response.usage.prompt_tokens, mistral_response.usage.completion_tokens, mistral_response.model
+ ),
+ )
+
+ return response_oai
+
+ @staticmethod
+ def get_usage(response: ChatCompletionResponse) -> Dict:
+ return {
+ "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0,
+ "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0,
+ "total_tokens": (
+ response.usage.prompt_tokens + response.usage.completion_tokens if response.usage is not None else 0
+ ),
+ "cost": response.cost if hasattr(response, "cost") else 0,
+ "model": response.model,
+ }
+
+
+def calculate_mistral_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
+ """Calculate the cost of the mistral response."""
+
+ # Prices per 1 million tokens
+ # https://mistral.ai/technology/
+ model_cost_map = {
+ "open-mistral-7b": {"input": 0.25, "output": 0.25},
+ "open-mixtral-8x7b": {"input": 0.7, "output": 0.7},
+ "open-mixtral-8x22b": {"input": 2.0, "output": 6.0},
+ "mistral-small-latest": {"input": 1.0, "output": 3.0},
+ "mistral-medium-latest": {"input": 2.7, "output": 8.1},
+ "mistral-large-latest": {"input": 4.0, "output": 12.0},
+ }
+
+ # Ensure we have the model they are using and return the total cost
+ if model_name in model_cost_map:
+ costs = model_cost_map[model_name]
+
+ return (input_tokens * costs["input"] / 1_000_000) + (output_tokens * costs["output"] / 1_000_000)
+ else:
+ warnings.warn(f"Cost calculation is not implemented for model {model_name}, will return $0.", UserWarning)
+ return 0
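As a quick sanity check of the pricing table (mirroring the unit test added below): 10 prompt tokens and 5 completion tokens on `mistral-large-latest` cost (10 × 4.0 + 5 × 12.0) / 1,000,000 = $0.0001.

```python
from autogen.oai.mistral import calculate_mistral_cost

# mistral-large-latest: $4.0 per 1M input tokens, $12.0 per 1M output tokens
cost = calculate_mistral_cost(input_tokens=10, output_tokens=5, model_name="mistral-large-latest")
assert cost == (10 * 4.0 + 5 * 12.0) / 1_000_000  # == 0.0001
```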
diff --git a/autogen/runtime_logging.py b/autogen/runtime_logging.py
index 3d238580bb0..0fe7e8d8b86 100644
--- a/autogen/runtime_logging.py
+++ b/autogen/runtime_logging.py
@@ -15,6 +15,7 @@
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
from autogen.oai.gemini import GeminiClient
+ from autogen.oai.mistral import MistralAIClient
logger = logging.getLogger(__name__)
@@ -108,7 +109,9 @@ def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig
def log_new_client(
- client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient], wrapper: OpenAIWrapper, init_args: Dict[str, Any]
+ client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient],
+ wrapper: OpenAIWrapper,
+ init_args: Dict[str, Any],
) -> None:
if autogen_logger is None:
logger.error("[runtime logging] log_new_client: autogen logger is None")
diff --git a/autogen/token_count_utils.py b/autogen/token_count_utils.py
index 6da6429d7e7..2842a749453 100644
--- a/autogen/token_count_utils.py
+++ b/autogen/token_count_utils.py
@@ -122,6 +122,9 @@ def _num_token_from_messages(messages: Union[List, Dict], model="gpt-3.5-turbo-0
elif "claude" in model:
logger.info("Claude is not supported in tiktoken. Returning num tokens assuming gpt-4-0613.")
return _num_token_from_messages(messages, model="gpt-4-0613")
+ elif "mistral-" in model or "mixtral-" in model:
+ logger.info("Mistral.AI models are not supported in tiktoken. Returning num tokens assuming gpt-4-0613.")
+ return _num_token_from_messages(messages, model="gpt-4-0613")
else:
raise NotImplementedError(
f"""_num_token_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
diff --git a/setup.py b/setup.py
index 20dc5e1bf70..346166270ab 100644
--- a/setup.py
+++ b/setup.py
@@ -89,6 +89,7 @@
"types": ["mypy==1.9.0", "pytest>=6.1.1,<8"] + jupyter_executor,
"long-context": ["llmlingua<0.3"],
"anthropic": ["anthropic>=0.23.1"],
+ "mistral": ["mistralai>=0.2.0"],
}
setuptools.setup(
diff --git a/test/oai/test_mistral.py b/test/oai/test_mistral.py
new file mode 100644
index 00000000000..5236f71d7b7
--- /dev/null
+++ b/test/oai/test_mistral.py
@@ -0,0 +1,170 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+try:
+ from mistralai.models.chat_completion import ChatMessage
+
+ from autogen.oai.mistral import MistralAIClient, calculate_mistral_cost
+
+ skip = False
+except ImportError:
+ MistralAIClient = object
+ skip = True
+
+
+# Fixtures for mock data
+@pytest.fixture
+def mock_response():
+ class MockResponse:
+ def __init__(self, text, choices, usage, cost, model):
+ self.text = text
+ self.choices = choices
+ self.usage = usage
+ self.cost = cost
+ self.model = model
+
+ return MockResponse
+
+
+@pytest.fixture
+def mistral_client():
+ return MistralAIClient(api_key="fake_api_key")
+
+
+# Test initialization and configuration
+@pytest.mark.skipif(skip, reason="Mistral.AI dependency is not installed")
+def test_initialization():
+
+ # Missing any api_key
+ with pytest.raises(AssertionError) as assertinfo:
+ MistralAIClient() # Should raise an AssertionError due to missing api_key
+
+ assert (
+ "Please specify the 'api_key' in your config list entry for Mistral or set the MISTRAL_API_KEY env variable."
+ in str(assertinfo.value)
+ )
+
+ # Creation works
+ MistralAIClient(api_key="fake_api_key") # Should create okay now.
+
+
+# Test standard initialization
+@pytest.mark.skipif(skip, reason="Mistral.AI dependency is not installed")
+def test_valid_initialization(mistral_client):
+ assert mistral_client.api_key == "fake_api_key", "Config api_key should be correctly set"
+
+
+# Test cost calculation
+@pytest.mark.skipif(skip, reason="Mistral.AI dependency is not installed")
+def test_cost_calculation(mock_response):
+ response = mock_response(
+ text="Example response",
+ choices=[{"message": "Test message 1"}],
+ usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
+ cost=None,
+ model="mistral-large-latest",
+ )
+ assert (
+ calculate_mistral_cost(response.usage["prompt_tokens"], response.usage["completion_tokens"], response.model)
+ == 0.0001
+ ), "Cost for this should be $0.0001"
+
+
+# Test text generation
+@pytest.mark.skipif(skip, reason="Mistral.AI dependency is not installed")
+@patch("autogen.oai.mistral.MistralClient.chat")
+def test_create_response(mock_chat, mistral_client):
+ # Mock MistralClient.chat response
+ mock_mistral_response = MagicMock()
+ mock_mistral_response.choices = [
+ MagicMock(finish_reason="stop", message=MagicMock(content="Example Mistral response", tool_calls=None))
+ ]
+ mock_mistral_response.id = "mock_mistral_response_id"
+ mock_mistral_response.model = "mistral-small-latest"
+ mock_mistral_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20) # Example token usage
+
+ mock_chat.return_value = mock_mistral_response
+
+ # Test parameters
+ params = {
+ "messages": [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "World"}],
+ "model": "mistral-small-latest",
+ }
+
+ # Call the create method
+ response = mistral_client.create(params)
+
+ # Assertions to check if response is structured as expected
+ assert (
+ response.choices[0].message.content == "Example Mistral response"
+ ), "Response content should match expected output"
+ assert response.id == "mock_mistral_response_id", "Response ID should match the mocked response ID"
+ assert response.model == "mistral-small-latest", "Response model should match the mocked response model"
+ assert response.usage.prompt_tokens == 10, "Response prompt tokens should match the mocked response usage"
+ assert response.usage.completion_tokens == 20, "Response completion tokens should match the mocked response usage"
+
+
+# Test functions/tools
+@pytest.mark.skipif(skip, reason="Mistral.AI dependency is not installed")
+@patch("autogen.oai.mistral.MistralClient.chat")
+def test_create_response_with_tool_call(mock_chat, mistral_client):
+ # Mock `mistral_response = client.chat(**mistral_params)`
+ mock_function = MagicMock(name="currency_calculator")
+ mock_function.name = "currency_calculator"
+ mock_function.arguments = '{"base_currency": "EUR", "quote_currency": "USD", "base_amount": 123.45}'
+
+ mock_function_2 = MagicMock(name="get_weather")
+ mock_function_2.name = "get_weather"
+ mock_function_2.arguments = '{"location": "Chicago"}'
+
+ mock_chat.return_value = MagicMock(
+ choices=[
+ MagicMock(
+ finish_reason="tool_calls",
+ message=MagicMock(
+ content="Sample text about the functions",
+ tool_calls=[
+ MagicMock(id="gdRdrvnHh", function=mock_function),
+ MagicMock(id="abRdrvnHh", function=mock_function_2),
+ ],
+ ),
+ )
+ ],
+ id="mock_mistral_response_id",
+ model="mistral-small-latest",
+ usage=MagicMock(prompt_tokens=10, completion_tokens=20),
+ )
+
+ # Construct parameters
+ converted_functions = [
+ {
+ "type": "function",
+ "function": {
+ "description": "Currency exchange calculator.",
+ "name": "currency_calculator",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "base_amount": {"type": "number", "description": "Amount of currency in base_currency"},
+ },
+ "required": ["base_amount"],
+ },
+ },
+ }
+ ]
+ mistral_messages = [
+ {"role": "user", "content": "How much is 123.45 EUR in USD?"},
+ {"role": "assistant", "content": "World"},
+ ]
+
+ # Call the create method
+ response = mistral_client.create(
+ {"messages": mistral_messages, "tools": converted_functions, "model": "mistral-medium-latest"}
+ )
+
+ # Assertions to check if the functions and content are included in the response
+ assert response.choices[0].message.content == "Sample text about the functions"
+ assert response.choices[0].message.tool_calls[0].function.name == "currency_calculator"
+ assert response.choices[0].message.tool_calls[1].function.name == "get_weather"
diff --git a/website/docs/topics/non-openai-models/cloud-mistralai.ipynb b/website/docs/topics/non-openai-models/cloud-mistralai.ipynb
index a95e0bc514f..1228f96db4e 100644
--- a/website/docs/topics/non-openai-models/cloud-mistralai.ipynb
+++ b/website/docs/topics/non-openai-models/cloud-mistralai.ipynb
@@ -6,37 +6,150 @@
"source": [
"# Mistral AI\n",
"\n",
- "[Mistral AI](https://mistral.ai/) is a cloud based platform\n",
- "serving Mistral's own LLMs.\n",
- "You can use AutoGen with Mistral AI's API directly."
+ "[Mistral AI](https://mistral.ai/) is a cloud based platform serving their own LLMs, like Mistral, Mixtral, and Codestral.\n",
+ "\n",
+ "Although AutoGen can be used with Mistral AI's API directly by changing the `base_url` to their url, it does not cater for some differences between messaging and, with their API being more strict than OpenAI's, it is recommended to use the Mistral AI Client class as shown in this notebook.\n",
+ "\n",
+ "You will need a Mistral.AI account and create an API key. [See their website for further details](https://mistral.ai/)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "First you need to install the `pyautogen` package to use AutoGen."
+ "## Features\n",
+ "\n",
+ "When using this client class, messages are automatically tailored to accommodate the specific requirements of Mistral AI's API (such as role orders), which have become more strict than OpenAI's API.\n",
+ "\n",
+ "Additionally, this client class provides support for function/tool calling and will track token usage and cost correctly as per Mistral AI's API costs (as of June 2024).\n",
+ "\n",
+ "## Getting started\n",
+ "\n",
+ "First you need to install the `pyautogen` package to use AutoGen with the Mistral API library.\n",
+ "\n",
+ "``` bash\n",
+ "pip install pyautogen[mistral]\n",
+ "```"
]
},
{
- "cell_type": "code",
- "execution_count": null,
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [],
"source": [
- "! pip install pyautogen"
+ "Mistral provides a number of models to use, included below. See the list of [models here](https://docs.mistral.ai/platform/endpoints/).\n",
+ "\n",
+ "See the sample `OAI_CONFIG_LIST` below showing how the Mistral AI client class is used by specifying the `api_type` as `mistral`.\n",
+ "\n",
+ "```python\n",
+ "[\n",
+ " {\n",
+ " \"model\": \"gpt-35-turbo\",\n",
+ " \"api_key\": \"your OpenAI Key goes here\",\n",
+ " },\n",
+ " {\n",
+ " \"model\": \"gpt-4-vision-preview\",\n",
+ " \"api_key\": \"your OpenAI Key goes here\",\n",
+ " },\n",
+ " {\n",
+ " \"model\": \"dalle\",\n",
+ " \"api_key\": \"your OpenAI Key goes here\",\n",
+ " },\n",
+ " {\n",
+ " \"model\": \"open-mistral-7b\",\n",
+ " \"api_key\": \"your Mistral AI API Key goes here\",\n",
+ " \"api_type\": \"mistral\"\n",
+ " },\n",
+ " {\n",
+ " \"model\": \"open-mixtral-8x7b\",\n",
+ " \"api_key\": \"your Mistral AI API Key goes here\",\n",
+ " \"api_type\": \"mistral\"\n",
+ " },\n",
+ " {\n",
+ " \"model\": \"open-mixtral-8x22b\",\n",
+ " \"api_key\": \"your Mistral AI API Key goes here\",\n",
+ " \"api_type\": \"mistral\"\n",
+ " },\n",
+ " {\n",
+ " \"model\": \"mistral-small-latest\",\n",
+ " \"api_key\": \"your Mistral AI API Key goes here\",\n",
+ " \"api_type\": \"mistral\"\n",
+ " },\n",
+ " {\n",
+ " \"model\": \"mistral-medium-latest\",\n",
+ " \"api_key\": \"your Mistral AI API Key goes here\",\n",
+ " \"api_type\": \"mistral\"\n",
+ " },\n",
+ " {\n",
+ " \"model\": \"mistral-large-latest\",\n",
+ " \"api_key\": \"your Mistral AI API Key goes here\",\n",
+ " \"api_type\": \"mistral\"\n",
+ " },\n",
+ " {\n",
+ " \"model\": \"codestral-latest\",\n",
+ " \"api_key\": \"your Mistral AI API Key goes here\",\n",
+ " \"api_type\": \"mistral\"\n",
+ " }\n",
+ "]\n",
+ "```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Now you can set up the Mistral model you want to use. See the list of [models here](https://docs.mistral.ai/platform/endpoints/)."
+ "As an alternative to the `api_key` key and value in the config, you can set the environment variable `MISTRAL_API_KEY` to your Mistral AI key.\n",
+ "\n",
+ "Linux/Mac:\n",
+ "``` bash\n",
+ "export MISTRAL_API_KEY=\"your_mistral_ai_api_key_here\"\n",
+ "```\n",
+ "\n",
+ "Windows:\n",
+ "``` bash\n",
+ "set MISTRAL_API_KEY=your_mistral_ai_api_key_here\n",
+ "```\n",
+ "\n",
+ "## API parameters\n",
+ "\n",
+ "The following parameters can be added to your config for the Mistral.AI API. See [this link](https://docs.mistral.ai/api/#operation/createChatCompletion) for further information on them and their default values.\n",
+ "\n",
+ "- temperature (number 0..1)\n",
+ "- top_p (number 0..1)\n",
+ "- max_tokens (null, integer >= 0)\n",
+ "- random_seed (null, integer)\n",
+ "- safe_prompt (True or False)\n",
+ "\n",
+ "Example:\n",
+ "```python\n",
+ "[\n",
+ " {\n",
+ " \"model\": \"codestral-latest\",\n",
+ " \"api_key\": \"your Mistral AI API Key goes here\",\n",
+ " \"api_type\": \"mistral\",\n",
+ " \"temperature\": 0.5,\n",
+ " \"top_p\": 0.2, # Note: It is recommended to set temperature or top_p but not both.\n",
+ " \"max_tokens\": 10000,\n",
+ " \"safe_prompt\": False,\n",
+ " \"random_seed\": 42\n",
+ " }\n",
+ "]\n",
+ "```\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Two-Agent Coding Example\n",
+ "\n",
+ "In this example, we run a two-agent chat with an AssistantAgent (primarily a coding agent) to generate code to count the number of prime numbers between 1 and 10,000 and then it will be executed.\n",
+ "\n",
+ "We'll use Mistral's Mixtral 8x22B model which is suitable for coding."
]
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -44,11 +157,12 @@
"\n",
"config_list = [\n",
" {\n",
- " # Choose your model name.\n",
- " \"model\": \"mistral-large-latest\",\n",
- " \"base_url\": \"https://api.mistral.ai/v1\",\n",
- " # You need to provide your API key here.\n",
+ " # Let's choose the Mixtral 8x22B model\n",
+ " \"model\": \"open-mixtral-8x22b\",\n",
+ " # Provide your Mistral AI API key here or put it into the MISTRAL_API_KEY environment variable.\n",
" \"api_key\": os.environ.get(\"MISTRAL_API_KEY\"),\n",
+ " # We specify the API Type as 'mistral' so it uses the Mistral AI client class\n",
+ " \"api_type\": \"mistral\",\n",
" }\n",
"]"
]
@@ -57,14 +171,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Two-Agent Coding Example\n",
- "\n",
- "In this example, we run a two-agent chat to count the number of prime numbers between 1 and 10,000 using coding."
+ "Importantly, we have tweaked the system message so that the model doesn't return the termination keyword, which we've changed to FINISH, with the code block."
]
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@@ -73,27 +185,41 @@
"from autogen import AssistantAgent, UserProxyAgent\n",
"from autogen.coding import LocalCommandLineCodeExecutor\n",
"\n",
- "# Setting up the code executor.\n",
+ "# Setting up the code executor\n",
"workdir = Path(\"coding\")\n",
"workdir.mkdir(exist_ok=True)\n",
"code_executor = LocalCommandLineCodeExecutor(work_dir=workdir)\n",
"\n",
- "# Setting up the agents.\n",
+ "# Setting up the agents\n",
+ "\n",
+ "# The UserProxyAgent will execute the code that the AssistantAgent provides\n",
"user_proxy_agent = UserProxyAgent(\n",
" name=\"User\",\n",
" code_execution_config={\"executor\": code_executor},\n",
- " is_termination_msg=lambda msg: \"TERMINATE\" in msg.get(\"content\"),\n",
+ " is_termination_msg=lambda msg: \"FINISH\" in msg.get(\"content\"),\n",
")\n",
"\n",
+ "system_message = \"\"\"You are a helpful AI assistant who writes code and the user executes it.\n",
+ "Solve tasks using your coding and language skills.\n",
+ "In the following cases, suggest python code (in a python coding block) for the user to execute.\n",
+ "Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.\n",
+ "When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.\n",
+ "Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.\n",
+ "If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.\n",
+ "When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.\n",
+ "IMPORTANT: Wait for the user to execute your code and then you can reply with the word \"FINISH\". DO NOT OUTPUT \"FINISH\" after your code block.\"\"\"\n",
+ "\n",
+ "# The AssistantAgent, using Mistral AI's model, will take the coding request and return code\n",
"assistant_agent = AssistantAgent(\n",
" name=\"Mistral Assistant\",\n",
+ " system_message=system_message,\n",
" llm_config={\"config_list\": config_list},\n",
")"
]
},
{
"cell_type": "code",
- "execution_count": 31,
+ "execution_count": 14,
"metadata": {},
"outputs": [
{
@@ -102,16 +228,23 @@
"text": [
"\u001b[33mUser\u001b[0m (to Mistral Assistant):\n",
"\n",
- "Count how many prime numbers from 1 to 10000.\n",
+ "Provide code to count the number of prime numbers from 1 to 10000.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mMistral Assistant\u001b[0m (to User):\n",
"\n",
- "Sure, I can help with that. Here's a Python code snippet that counts the number of prime numbers from 1 to 10000.\n",
+ "To solve this task, I will write a Python function that checks if a number is prime or not. Then, I will iterate through the numbers from 1 to 10000 and count the prime numbers. Here's the plan:\n",
"\n",
- "```python\n",
- "# filename: prime_counter.py\n",
+ "1. Write a function `is_prime(n)` that checks if a number `n` is prime.\n",
+ "2. Initialize a variable `prime_count` to 0.\n",
+ "3. Iterate through numbers from 1 to 10000 using a for loop.\n",
+ "4. For each number, call the `is_prime(n)` function.\n",
+ "5. If the function returns True, increment the `prime_count` by 1.\n",
+ "6. Finally, print the `prime_count`.\n",
+ "\n",
+ "Here's the code that implements this plan:\n",
"\n",
+ "```python\n",
"def is_prime(n):\n",
" if n <= 1:\n",
" return False\n",
@@ -126,15 +259,37 @@
" i += 6\n",
" return True\n",
"\n",
- "count = 0\n",
+ "prime_count = 0\n",
"for num in range(1, 10001):\n",
" if is_prime(num):\n",
- " count += 1\n",
+ " prime_count += 1\n",
"\n",
- "print(count)\n",
+ "print(prime_count)\n",
"```\n",
"\n",
- "Please save this code in a file named `prime_counter.py` and run it. The output will be the count of prime numbers from 1 to 10000.\n",
+ "This code will count the number of prime numbers from 1 to 10000 and print the result. Please execute the code and let me know the output.\n",
+ "```python\n",
+ "def is_prime(n):\n",
+ " if n <= 1:\n",
+ " return False\n",
+ " if n <= 3:\n",
+ " return True\n",
+ " if n % 2 == 0 or n % 3 == 0:\n",
+ " return False\n",
+ " i = 5\n",
+ " while i * i <= n:\n",
+ " if n % i == 0 or n % (i + 2) == 0:\n",
+ " return False\n",
+ " i += 6\n",
+ " return True\n",
+ "\n",
+ "prime_count = 0\n",
+ "for num in range(1, 10001):\n",
+ " if is_prime(num):\n",
+ " prime_count += 1\n",
+ "\n",
+ "prime_count\n",
+ "```\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
@@ -142,7 +297,7 @@
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[31m\n",
- ">>>>>>>> EXECUTING CODE BLOCK (inferred language is python)...\u001b[0m\n",
+ ">>>>>>>> EXECUTING 2 CODE BLOCKS (inferred languages are [python, python])...\u001b[0m\n",
"\u001b[33mUser\u001b[0m (to Mistral Assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
@@ -152,17 +307,15 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mMistral Assistant\u001b[0m (to User):\n",
"\n",
- "Based on the output, the code I provided earlier has successfully executed and found that there are 1229 prime numbers between 1 and 10000. Here's how I approached this task step by step:\n",
- "\n",
- "1. I wrote a Python function `is_prime(n)` to check if a given number `n` is prime. This function returns `True` if `n` is prime and `False` otherwise.\n",
+ "The code executed successfully and the output is 1229.\n",
"\n",
- "2. I used a for loop to iterate through numbers from 1 to 10000, then called the `is_prime` function to determine if the current number is prime. If it is, I incremented a counter variable `count` by 1.\n",
+ "This means there are 1229 prime numbers from 1 to 10000.\n",
"\n",
- "3. I printed the value of `count` after the loop to display the total number of prime numbers in the given range.\n",
+ "The code defines a function `is_prime(n)` that checks if a number `n` is prime or not. Then, it initializes a variable `prime_count` to 0. It iterates through numbers from 1 to 10000 using a for loop. For each number, it calls the `is_prime(n)` function. If the function returns True, indicating that the number is prime, it increments the `prime_count` by 1. Finally, it prints the `prime_count`.\n",
"\n",
- "The output `1229` confirms that there are indeed 1229 prime numbers between 1 and 10000.\n",
+ "This code efficiently checks for prime numbers and accurately counts the number of prime numbers from 1 to 10000.\n",
"\n",
- "TERMINATE\n",
+ "FINISH\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
@@ -171,9 +324,10 @@
}
],
"source": [
+ "# Start the chat, with the UserProxyAgent asking the AssistantAgent the message\n",
"chat_result = user_proxy_agent.initiate_chat(\n",
" assistant_agent,\n",
- " message=\"Count how many prime numbers from 1 to 10000.\",\n",
+ " message=\"Provide code to count the number of prime numbers from 1 to 10000.\",\n",
")"
]
},
@@ -185,14 +339,47 @@
"\n",
"In this example, instead of writing code, we will have two agents playing chess against each other using tool calling to make moves.\n",
"\n",
- "First install the `chess` package by running the following command:"
+ "We'll change models to Mistral AI's large model for this challenging task."
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 15,
"metadata": {},
"outputs": [],
+ "source": [
+ "config_list = [\n",
+ " {\n",
+ " # Let's choose the Mistral AI's largest model which is better at Chess than the Mixtral model\n",
+ " \"model\": \"mistral-large-latest\",\n",
+ " \"api_key\": os.environ.get(\"MISTRAL_API_KEY\"),\n",
+ " \"api_type\": \"mistral\",\n",
+ " }\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "First install the `chess` package by running the following command:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Defaulting to user installation because normal site-packages is not writeable\n",
+ "Requirement already satisfied: chess in /home/autogen/.local/lib/python3.11/site-packages (1.10.0)\n"
+ ]
+ }
+ ],
"source": [
"! pip install chess"
]
@@ -206,7 +393,7 @@
},
{
"cell_type": "code",
- "execution_count": 50,
+ "execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@@ -241,7 +428,7 @@
},
{
"cell_type": "code",
- "execution_count": 51,
+ "execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
@@ -277,9 +464,18 @@
},
{
"cell_type": "code",
- "execution_count": 52,
+ "execution_count": 19,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/autogen/autogen/autogen/agentchat/conversable_agent.py:2408: UserWarning: Function 'make_move' is being overridden.\n",
+ " warnings.warn(f\"Function '{name}' is being overridden.\", UserWarning)\n"
+ ]
+ }
+ ],
"source": [
"register_function(\n",
" make_move,\n",
@@ -311,7 +507,7 @@
},
{
"cell_type": "code",
- "execution_count": 53,
+ "execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@@ -340,12 +536,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Start the chess game."
+ "Clear the board and start the chess game."
]
},
{
"cell_type": "code",
- "execution_count": 49,
+ "execution_count": 21,
"metadata": {},
"outputs": [
{
@@ -361,13 +557,7 @@
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[34m\n",
"********************************************************************************\u001b[0m\n",
- "\u001b[34mStarting a new chat....\n",
- "\n",
- "Message:\n",
- "Let's play chess! Your move.\n",
- "\n",
- "Carryover: \n",
- "\u001b[0m\n",
+ "\u001b[34mStarting a new chat....\u001b[0m\n",
"\u001b[34m\n",
"********************************************************************************\u001b[0m\n",
"\u001b[33mBoard Proxy\u001b[0m (to Player Black):\n",
@@ -380,10 +570,10 @@
"\u001b[33mPlayer Black\u001b[0m (to Board Proxy):\n",
"\n",
"\n",
- "\u001b[32m***** Suggested tool call (No tool call id found): make_move *****\u001b[0m\n",
+ "\u001b[32m***** Suggested tool call (AcS1aX4Rl): make_move *****\u001b[0m\n",
"Arguments: \n",
"{}\n",
- "\u001b[32m******************************************************************\u001b[0m\n",
+ "\u001b[32m******************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
@@ -400,17 +590,24 @@
". . . . . . . .\n",
". . . . . . . .\n",
". . . . . . . .\n",
- "P . . . . . . .\n",
- ". P P P P P P P\n",
- "R N B Q K B N R"
+ ". . . . . . . N\n",
+ "P P P P P P P P\n",
+ "R N B Q K B . R"
],
"text/plain": [
- "''"
+ "''"
]
},
"metadata": {},
"output_type": "display_data"
},
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[runtime logging] log_function_use: autogen logger is None\n"
+ ]
+ },
{
"name": "stdout",
"output_type": "stream",
@@ -419,57 +616,35 @@
"\n",
"\u001b[33mBoard Proxy\u001b[0m (to Player Black):\n",
"\n",
- "\u001b[32m***** Response from calling tool (No id found) *****\u001b[0m\n",
- "a2a3\n",
- "\u001b[32m****************************************************\u001b[0m\n",
+ "\u001b[32m***** Response from calling tool (AcS1aX4Rl) *****\u001b[0m\n",
+ "g1h3\n",
+ "\u001b[32m**************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mPlayer Black\u001b[0m (to Board Proxy):\n",
"\n",
- "You made a move: a2a3. It's my turn now.\n",
- "\n",
- "e2e4\n",
- "\n",
- "Your move.\n",
+ "I have made my move. It's g1h3. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n",
"\u001b[33mPlayer Black\u001b[0m (to Player White):\n",
"\n",
- "You made a move: a2a3. It's my turn now.\n",
- "\n",
- "e2e4\n",
- "\n",
- "Your move.\n",
+ "I have made my move. It's g1h3. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[34m\n",
"********************************************************************************\u001b[0m\n",
- "\u001b[34mStarting a new chat....\n",
- "\n",
- "Message:\n",
- "You made a move: a2a3. It's my turn now.\n",
- "\n",
- "e2e4\n",
- "\n",
- "Your move.\n",
- "\n",
- "Carryover: \n",
- "\u001b[0m\n",
+ "\u001b[34mStarting a new chat....\u001b[0m\n",
"\u001b[34m\n",
"********************************************************************************\u001b[0m\n",
"\u001b[33mBoard Proxy\u001b[0m (to Player White):\n",
"\n",
- "You made a move: a2a3. It's my turn now.\n",
- "\n",
- "e2e4\n",
- "\n",
- "Your move.\n",
+ "I have made my move. It's g1h3. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
@@ -477,10 +652,10 @@
"\u001b[33mPlayer White\u001b[0m (to Board Proxy):\n",
"\n",
"\n",
- "\u001b[32m***** Suggested tool call (No tool call id found): make_move *****\u001b[0m\n",
+ "\u001b[32m***** Suggested tool call (tWVXVAujE): make_move *****\u001b[0m\n",
"Arguments: \n",
"{}\n",
- "\u001b[32m******************************************************************\u001b[0m\n",
+ "\u001b[32m******************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
@@ -493,21 +668,28 @@
"data": {
"image/svg+xml": [
""
+ ". . . . . . . N\n",
+ "P P P P P P P P\n",
+ "R N B Q K B . R"
],
"text/plain": [
- "''"
+ "''"
]
},
"metadata": {},
"output_type": "display_data"
},
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[runtime logging] log_function_use: autogen logger is None\n"
+ ]
+ },
{
"name": "stdout",
"output_type": "stream",
@@ -516,49 +698,35 @@
"\n",
"\u001b[33mBoard Proxy\u001b[0m (to Player White):\n",
"\n",
- "\u001b[32m***** Response from calling tool (No id found) *****\u001b[0m\n",
- "e7e5\n",
- "\u001b[32m****************************************************\u001b[0m\n",
+ "\u001b[32m***** Response from calling tool (tWVXVAujE) *****\u001b[0m\n",
+ "d7d5\n",
+ "\u001b[32m**************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mPlayer White\u001b[0m (to Board Proxy):\n",
"\n",
- "I made a move: e7e5. It's your turn now.\n",
- "\n",
- "Your move.\n",
+ "I have made my move. It's d7d5. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n",
"\u001b[33mPlayer White\u001b[0m (to Player Black):\n",
"\n",
- "I made a move: e7e5. It's your turn now.\n",
- "\n",
- "Your move.\n",
+ "I have made my move. It's d7d5. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[34m\n",
"********************************************************************************\u001b[0m\n",
- "\u001b[34mStarting a new chat....\n",
- "\n",
- "Message:\n",
- "I made a move: e7e5. It's your turn now.\n",
- "\n",
- "Your move.\n",
- "\n",
- "Carryover: \n",
- "\u001b[0m\n",
+ "\u001b[34mStarting a new chat....\u001b[0m\n",
"\u001b[34m\n",
"********************************************************************************\u001b[0m\n",
"\u001b[33mBoard Proxy\u001b[0m (to Player Black):\n",
"\n",
- "I made a move: e7e5. It's your turn now.\n",
- "\n",
- "Your move.\n",
+ "I have made my move. It's d7d5. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
@@ -566,10 +734,10 @@
"\u001b[33mPlayer Black\u001b[0m (to Board Proxy):\n",
"\n",
"\n",
- "\u001b[32m***** Suggested tool call (No tool call id found): make_move *****\u001b[0m\n",
+ "\u001b[32m***** Suggested tool call (ZOfvRz0B1): make_move *****\u001b[0m\n",
"Arguments: \n",
"{}\n",
- "\u001b[32m******************************************************************\u001b[0m\n",
+ "\u001b[32m******************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
@@ -582,21 +750,28 @@
"data": {
"image/svg+xml": [
""
+ ". . . p . . . .\n",
+ ". . . . . . . .\n",
+ ". . N . . . . N\n",
+ "P P P P P P P P\n",
+ "R . B Q K B . R"
],
"text/plain": [
- "''"
+ "''"
]
},
"metadata": {},
"output_type": "display_data"
},
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[runtime logging] log_function_use: autogen logger is None\n"
+ ]
+ },
{
"name": "stdout",
"output_type": "stream",
@@ -605,49 +780,35 @@
"\n",
"\u001b[33mBoard Proxy\u001b[0m (to Player Black):\n",
"\n",
- "\u001b[32m***** Response from calling tool (No id found) *****\u001b[0m\n",
- "h2h4\n",
- "\u001b[32m****************************************************\u001b[0m\n",
+ "\u001b[32m***** Response from calling tool (ZOfvRz0B1) *****\u001b[0m\n",
+ "b1c3\n",
+ "\u001b[32m**************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mPlayer Black\u001b[0m (to Board Proxy):\n",
"\n",
- "I made a move: h2h4. It's your turn now.\n",
- "\n",
- "Your move.\n",
+ "I have made my move. It's b1c3. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n",
"\u001b[33mPlayer Black\u001b[0m (to Player White):\n",
"\n",
- "I made a move: h2h4. It's your turn now.\n",
- "\n",
- "Your move.\n",
+ "I have made my move. It's b1c3. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[34m\n",
"********************************************************************************\u001b[0m\n",
- "\u001b[34mStarting a new chat....\n",
- "\n",
- "Message:\n",
- "I made a move: h2h4. It's your turn now.\n",
- "\n",
- "Your move.\n",
- "\n",
- "Carryover: \n",
- "\u001b[0m\n",
+ "\u001b[34mStarting a new chat....\u001b[0m\n",
"\u001b[34m\n",
"********************************************************************************\u001b[0m\n",
"\u001b[33mBoard Proxy\u001b[0m (to Player White):\n",
"\n",
- "I made a move: h2h4. It's your turn now.\n",
- "\n",
- "Your move.\n",
+ "I have made my move. It's b1c3. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
@@ -655,10 +816,10 @@
"\u001b[33mPlayer White\u001b[0m (to Board Proxy):\n",
"\n",
"\n",
- "\u001b[32m***** Suggested tool call (No tool call id found): make_move *****\u001b[0m\n",
+ "\u001b[32m***** Suggested tool call (LovRpi6Pq): make_move *****\u001b[0m\n",
"Arguments: \n",
"{}\n",
- "\u001b[32m******************************************************************\u001b[0m\n",
+ "\u001b[32m******************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
@@ -670,22 +831,29 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
- "''"
+ "''"
]
},
"metadata": {},
"output_type": "display_data"
},
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[runtime logging] log_function_use: autogen logger is None\n"
+ ]
+ },
{
"name": "stdout",
"output_type": "stream",
@@ -694,49 +862,33 @@
"\n",
"\u001b[33mBoard Proxy\u001b[0m (to Player White):\n",
"\n",
- "\u001b[32m***** Response from calling tool (No id found) *****\u001b[0m\n",
- "g8h6\n",
- "\u001b[32m****************************************************\u001b[0m\n",
+ "\u001b[32m***** Response from calling tool (LovRpi6Pq) *****\u001b[0m\n",
+ "c8g4\n",
+ "\u001b[32m**************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mPlayer White\u001b[0m (to Board Proxy):\n",
"\n",
- "You moved g8h6. I made a move: g1g3. It's your turn now.\n",
- "\n",
- "Your move.\n",
+ "I have made my move. It's c8g4. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n",
- "\u001b[31m\n",
- ">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n",
"\u001b[33mPlayer White\u001b[0m (to Player Black):\n",
"\n",
- "You moved g8h6. I made a move: g1g3. It's your turn now.\n",
- "\n",
- "Your move.\n",
+ "I have made my move. It's c8g4. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[34m\n",
"********************************************************************************\u001b[0m\n",
- "\u001b[34mStarting a new chat....\n",
- "\n",
- "Message:\n",
- "You moved g8h6. I made a move: g1g3. It's your turn now.\n",
- "\n",
- "Your move.\n",
- "\n",
- "Carryover: \n",
- "\u001b[0m\n",
+ "\u001b[34mStarting a new chat....\u001b[0m\n",
"\u001b[34m\n",
"********************************************************************************\u001b[0m\n",
"\u001b[33mBoard Proxy\u001b[0m (to Player Black):\n",
"\n",
- "You moved g8h6. I made a move: g1g3. It's your turn now.\n",
- "\n",
- "Your move.\n",
+ "I have made my move. It's c8g4. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
@@ -744,10 +896,10 @@
"\u001b[33mPlayer Black\u001b[0m (to Board Proxy):\n",
"\n",
"\n",
- "\u001b[32m***** Suggested tool call (No tool call id found): make_move *****\u001b[0m\n",
+ "\u001b[32m***** Suggested tool call (RfhC8brG7): make_move *****\u001b[0m\n",
"Arguments: \n",
"{}\n",
- "\u001b[32m******************************************************************\u001b[0m\n",
+ "\u001b[32m******************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
@@ -759,22 +911,29 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
- "''"
+ "''"
]
},
"metadata": {},
"output_type": "display_data"
},
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[runtime logging] log_function_use: autogen logger is None\n"
+ ]
+ },
{
"name": "stdout",
"output_type": "stream",
@@ -783,49 +942,33 @@
"\n",
"\u001b[33mBoard Proxy\u001b[0m (to Player Black):\n",
"\n",
- "\u001b[32m***** Response from calling tool (No id found) *****\u001b[0m\n",
- "g1h3\n",
- "\u001b[32m****************************************************\u001b[0m\n",
+ "\u001b[32m***** Response from calling tool (RfhC8brG7) *****\u001b[0m\n",
+ "a1b1\n",
+ "\u001b[32m**************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mPlayer Black\u001b[0m (to Board Proxy):\n",
"\n",
- "You moved g8h6. I made a move: g1h3. You moved g1h3. It's my turn now.\n",
- "\n",
- "I made a move: d2d4. Your move.\n",
+ "I have made my move. It's a1b1. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n",
- "\u001b[31m\n",
- ">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n",
"\u001b[33mPlayer Black\u001b[0m (to Player White):\n",
"\n",
- "You moved g8h6. I made a move: g1h3. You moved g1h3. It's my turn now.\n",
- "\n",
- "I made a move: d2d4. Your move.\n",
+ "I have made my move. It's a1b1. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[34m\n",
"********************************************************************************\u001b[0m\n",
- "\u001b[34mStarting a new chat....\n",
- "\n",
- "Message:\n",
- "You moved g8h6. I made a move: g1h3. You moved g1h3. It's my turn now.\n",
- "\n",
- "I made a move: d2d4. Your move.\n",
- "\n",
- "Carryover: \n",
- "\u001b[0m\n",
+ "\u001b[34mStarting a new chat....\u001b[0m\n",
"\u001b[34m\n",
"********************************************************************************\u001b[0m\n",
"\u001b[33mBoard Proxy\u001b[0m (to Player White):\n",
"\n",
- "You moved g8h6. I made a move: g1h3. You moved g1h3. It's my turn now.\n",
- "\n",
- "I made a move: d2d4. Your move.\n",
+ "I have made my move. It's a1b1. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
@@ -833,10 +976,10 @@
"\u001b[33mPlayer White\u001b[0m (to Board Proxy):\n",
"\n",
"\n",
- "\u001b[32m***** Suggested tool call (No tool call id found): make_move *****\u001b[0m\n",
+ "\u001b[32m***** Suggested tool call (6aVW1t0lm): make_move *****\u001b[0m\n",
"Arguments: \n",
"{}\n",
- "\u001b[32m******************************************************************\u001b[0m\n",
+ "\u001b[32m******************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
@@ -848,22 +991,29 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
- "''"
+ "''"
]
},
"metadata": {},
"output_type": "display_data"
},
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[runtime logging] log_function_use: autogen logger is None\n"
+ ]
+ },
{
"name": "stdout",
"output_type": "stream",
@@ -872,41 +1022,33 @@
"\n",
"\u001b[33mBoard Proxy\u001b[0m (to Player White):\n",
"\n",
- "\u001b[32m***** Response from calling tool (No id found) *****\u001b[0m\n",
- "d8h4\n",
- "\u001b[32m****************************************************\u001b[0m\n",
+ "\u001b[32m***** Response from calling tool (6aVW1t0lm) *****\u001b[0m\n",
+ "a7a6\n",
+ "\u001b[32m**************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mPlayer White\u001b[0m (to Board Proxy):\n",
"\n",
- "You moved d8h4. I made a move: d4d5. Your move.\n",
+ "I have made my move. It's a7a6. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n",
- "\u001b[31m\n",
- ">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n",
"\u001b[33mPlayer White\u001b[0m (to Player Black):\n",
"\n",
- "You moved d8h4. I made a move: d4d5. Your move.\n",
+ "I have made my move. It's a7a6. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[34m\n",
"********************************************************************************\u001b[0m\n",
- "\u001b[34mStarting a new chat....\n",
- "\n",
- "Message:\n",
- "You moved d8h4. I made a move: d4d5. Your move.\n",
- "\n",
- "Carryover: \n",
- "\u001b[0m\n",
+ "\u001b[34mStarting a new chat....\u001b[0m\n",
"\u001b[34m\n",
"********************************************************************************\u001b[0m\n",
"\u001b[33mBoard Proxy\u001b[0m (to Player Black):\n",
"\n",
- "You moved d8h4. I made a move: d4d5. Your move.\n",
+ "I have made my move. It's a7a6. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
@@ -914,10 +1056,10 @@
"\u001b[33mPlayer Black\u001b[0m (to Board Proxy):\n",
"\n",
"\n",
- "\u001b[32m***** Suggested tool call (No tool call id found): make_move *****\u001b[0m\n",
+ "\u001b[32m***** Suggested tool call (kPTEInlLR): make_move *****\u001b[0m\n",
"Arguments: \n",
"{}\n",
- "\u001b[32m******************************************************************\u001b[0m\n",
+ "\u001b[32m******************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
@@ -929,22 +1071,29 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
- "''"
+ "''"
]
},
"metadata": {},
"output_type": "display_data"
},
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[runtime logging] log_function_use: autogen logger is None\n"
+ ]
+ },
{
"name": "stdout",
"output_type": "stream",
@@ -953,23 +1102,21 @@
"\n",
"\u001b[33mBoard Proxy\u001b[0m (to Player Black):\n",
"\n",
- "\u001b[32m***** Response from calling tool (No id found) *****\u001b[0m\n",
- "e2e4\n",
- "\u001b[32m****************************************************\u001b[0m\n",
+ "\u001b[32m***** Response from calling tool (kPTEInlLR) *****\u001b[0m\n",
+ "h3f4\n",
+ "\u001b[32m**************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mPlayer Black\u001b[0m (to Board Proxy):\n",
"\n",
- "You made a move: e2e4. I made a move: d5e4. Your move.\n",
+ "I have made my move. It's h3f4. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n",
- "\u001b[31m\n",
- ">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n",
"\u001b[33mPlayer Black\u001b[0m (to Player White):\n",
"\n",
- "You made a move: e2e4. I made a move: d5e4. Your move.\n",
+ "I have made my move. It's h3f4. Your turn.\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
@@ -1003,7 +1150,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.5"
+ "version": "3.11.9"
}
},
"nbformat": 4,