From c82650c5d27a6ad4554c3fd2575a689dbb66de83 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Mon, 1 Jul 2024 20:04:32 +0000 Subject: [PATCH 01/14] Ollama client! With function calling. Initial commit, client, no docs or tests yet. --- .github/workflows/contrib-tests.yml | 40 +++ autogen/logger/file_logger.py | 12 +- autogen/logger/sqlite_logger.py | 12 +- autogen/oai/client.py | 12 + autogen/oai/ollama.py | 500 ++++++++++++++++++++++++++++ autogen/runtime_logging.py | 5 +- setup.py | 1 + test/oai/test_ollama.py | 14 + 8 files changed, 593 insertions(+), 3 deletions(-) create mode 100644 autogen/oai/ollama.py create mode 100644 test/oai/test_ollama.py diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml index 895e810022d..31ea16fabd4 100644 --- a/.github/workflows/contrib-tests.yml +++ b/.github/workflows/contrib-tests.yml @@ -638,3 +638,43 @@ jobs: with: file: ./coverage.xml flags: unittests + + OllamaTest: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-2019] + python-version: ["3.9", "3.10", "3.11", "3.12"] + exclude: + - os: macos-latest + python-version: "3.9" + steps: + - uses: actions/checkout@v4 + with: + lfs: true + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install packages and dependencies for all tests + run: | + python -m pip install --upgrade pip wheel + pip install pytest-cov>=5 + - name: Install packages and dependencies for Ollama + run: | + pip install -e .[ollama,test] + - name: Set AUTOGEN_USE_DOCKER based on OS + shell: bash + run: | + if [[ ${{ matrix.os }} != ubuntu-latest ]]; then + echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV + fi + - name: Coverage + run: | + pytest test/oai/test_ollama.py --skip-openai + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + file: ./coverage.xml + flags: unittests diff --git a/autogen/logger/file_logger.py b/autogen/logger/file_logger.py index cdebbdc0eb7..923eeac78be 100644 --- a/autogen/logger/file_logger.py +++ b/autogen/logger/file_logger.py @@ -21,6 +21,7 @@ from autogen.oai.gemini import GeminiClient from autogen.oai.groq import GroqClient from autogen.oai.mistral import MistralAIClient + from autogen.oai.ollama import OllamaClient from autogen.oai.together import TogetherClient logger = logging.getLogger(__name__) @@ -205,7 +206,16 @@ def log_new_wrapper( def log_new_client( self, - client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient | GroqClient, + client: ( + AzureOpenAI + | OpenAI + | GeminiClient + | AnthropicClient + | MistralAIClient + | TogetherClient + | GroqClient + | OllamaClient + ), wrapper: OpenAIWrapper, init_args: Dict[str, Any], ) -> None: diff --git a/autogen/logger/sqlite_logger.py b/autogen/logger/sqlite_logger.py index ccde6bd1d81..1d5dbbe1616 100644 --- a/autogen/logger/sqlite_logger.py +++ b/autogen/logger/sqlite_logger.py @@ -22,6 +22,7 @@ from autogen.oai.gemini import GeminiClient from autogen.oai.groq import GroqClient from autogen.oai.mistral import MistralAIClient + from autogen.oai.ollama import OllamaClient from autogen.oai.together import TogetherClient logger = logging.getLogger(__name__) @@ -392,7 +393,16 @@ def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[st def log_new_client( self, - client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, GroqClient], + client: Union[ + AzureOpenAI, + OpenAI, + GeminiClient, + AnthropicClient, + MistralAIClient, + TogetherClient, + GroqClient, + OllamaClient, + ], wrapper: OpenAIWrapper, init_args: Dict[str, Any], ) -> None: diff --git a/autogen/oai/client.py b/autogen/oai/client.py index a7b12ce83da..263a7db3b35 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -77,6 +77,13 @@ except ImportError as e: groq_import_exception = e +try: + from autogen.oai.ollama import OllamaClient + + ollama_import_exception: Optional[ImportError] = None +except ImportError as e: + ollama_import_exception = e + logger = logging.getLogger(__name__) if not logger.handlers: # Add the console handler. @@ -497,6 +504,11 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s raise ImportError("Please install `groq` to use the Groq API.") client = GroqClient(**openai_config) self._clients.append(client) + elif api_type is not None and api_type.startswith("ollama"): + if ollama_import_exception: + raise ImportError("Please install `ollama` to use the Ollama API.") + client = OllamaClient(**openai_config) + self._clients.append(client) else: client = OpenAI(**openai_config) self._clients.append(OpenAIClient(client)) diff --git a/autogen/oai/ollama.py b/autogen/oai/ollama.py new file mode 100644 index 00000000000..0acdc6b8369 --- /dev/null +++ b/autogen/oai/ollama.py @@ -0,0 +1,500 @@ +"""Create an OpenAI-compatible client using Ollama's API. + +Example: + llm_config={ + "config_list": [{ + "api_type": "ollama", + "model": "mistral:7b-instruct-v0.3-q6_K" + } + ]} + + agent = autogen.AssistantAgent("my_agent", llm_config=llm_config) + +Install Ollama's python library using: pip install --upgrade ollama + +Resources: +- https://github.com/ollama/ollama-python +""" + +from __future__ import annotations + +import copy +import json +import random +import re +import time +from typing import Any, Dict, List, Tuple + +import ollama +from ollama import Client +from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall +from openai.types.chat.chat_completion import ChatCompletionMessage, Choice +from openai.types.completion_usage import CompletionUsage + +from autogen.oai.client_utils import should_hide_tools, validate_parameter + + +class OllamaClient: + """Client for Ollama's API.""" + + # Defaults for manual tool calling + # Instruction is added to the first system message and provides directions to follow a two step + # process + # 1. (before tools have been called) Return JSON with the functions to call + # 2. (directly after tools have been called) Return Text describing the results of the function calls in text format + + # Override using "manual_tool_call_instruction" config parameter + TOOL_CALL_MANUAL_INSTRUCTION = ( + "You are to follow a strict two step process that will occur over " + "a number of interactions, so pay attention to what step you are in based on the full " + "conversation. We will be taking turns so only do one step at a time so don't perform step " + "2 until step 1 is complete and I've told you the result. The first step is to choose one " + "or more functions based on the request given and return only JSON with the functions and " + "arguments to use. The second step is to analyse the given output of the function and summarise " + "it returning only TEXT and not Python or JSON. " + "In terms of your response format, for step 1 return only JSON and NO OTHER text, " + "for step 2 return only text and NO JSON/Python/Markdown. " + 'The format for running a function is [{"name": "function_name1", "arguments":{"argument_name": "argument_value"}},{"name": "function_name2", "arguments":{"argument_name": "argument_value"}}]\n' + "The following functions are available to you:[FUNCTIONS_LIST]" + ) + + # Appended to the last user message if no tools have been called + # Override using "manual_tool_call_step1" config parameter + TOOL_CALL_MANUAL_STEP1 = " (proceed with step 1)" + + # Appended to the user message after tools have been executed. Will create a 'user' message if one doesn't exist. + # Override using "manual_tool_call_step2" config parameter + TOOL_CALL_MANUAL_STEP2 = " (proceed with step 2)" + + def __init__(self, **kwargs): + """Note that no api_key or environment variable is required for Ollama. + + Args: + None + """ + + def message_retrieval(self, response) -> List: + """ + Retrieve and return a list of strings or a list of Choice.Message from the response. + + NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object, + since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used. + """ + return [choice.message for choice in response.choices] + + def cost(self, response) -> float: + return response.cost + + @staticmethod + def get_usage(response) -> Dict: + """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" + # ... # pragma: no cover + return { + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "total_tokens": response.usage.total_tokens, + "cost": response.cost, + "model": response.model, + } + + def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: + """Loads the parameters for Ollama API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults""" + ollama_params = {} + + # Check that we have what we need to use Ollama's API + # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion + + # The main parameters are model, prompt, stream, and options + # Options is a dictionary of parameters for the model + # There are other, advanced, parameters such as format, system (to override system message), template, raw, etc. - not used + + # We won't enforce the available models + ollama_params["model"] = params.get("model", None) + assert ollama_params[ + "model" + ], "Please specify the 'model' in your config list entry to nominate the Ollama model to use. The model must start with 'ollama/' or 'ollama_chat/'." + + ollama_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None) + + # Build up the options dictionary + # https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values + options_dict = {} + + if "num_predict" in params: + # Maximum number of tokens to predict, note: -1 is infinite, -2 is fill context, 128 is default + ollama_params["num_predict"] = validate_parameter(params, "num_predict", int, False, 128, None, None) + + if "repeat_penalty" in params: + options_dict["repeat_penalty"] = validate_parameter( + params, "repeat_penalty", (int, float), False, 1.1, None, None + ) + + if "seed" in params: + options_dict["seed"] = validate_parameter(params, "seed", int, False, 42, None, None) + + if "temperature" in params: + ollama_params["temperature"] = validate_parameter( + params, "temperature", (int, float), False, 0.8, None, None + ) + + if "top_k" in params: + ollama_params["top_k"] = validate_parameter(params, "top_k", int, False, 40, None, None) + + if "top_p" in params: + ollama_params["top_p"] = validate_parameter(params, "top_p", (int, float), False, 0.9, None, None) + + if len(options_dict) != 0: + ollama_params["options"] = options_dict + + return ollama_params + + def create(self, params: Dict) -> ChatCompletion: + + messages = params.get("messages", []) + + # Are tools involved in this conversation? + self._tools_in_conversation = "tools" in params + + # Function/Tool calling options + # For the time-being Ollama does not support tool calling, so we will handle this + # manually by providing guidance to the LLM and parsing responses to look for tool calls + # This variable could be omitted but I think it is useful to keep in for now. + self._tool_calling_mode = "manual" + + if self._tool_calling_mode == "manual": + # Load defaults + self._manual_tool_call_instruction = validate_parameter( + params, "manual_tool_call_instruction", str, False, self.TOOL_CALL_MANUAL_INSTRUCTION, None, None + ) + self._manual_tool_call_step1 = validate_parameter( + params, "manual_tool_call_step1", str, False, self.TOOL_CALL_MANUAL_STEP1, None, None + ) + self._manual_tool_call_step2 = validate_parameter( + params, "manual_tool_call_step2", str, False, self.TOOL_CALL_MANUAL_STEP2, None, None + ) + + # Convert AutoGen messages to Ollama messages + ollama_messages = self.oai_messages_to_ollama_messages( + messages, params["tools"] if self._tools_in_conversation else None + ) + + # Parse parameters to the Ollama API's parameters + ollama_params = self.parse_params(params) + + # Add tools to the call if we have them and aren't hiding them + if self._tools_in_conversation: + # For Ollama we will inject the available tools into the prompt + ollama_params["format"] = "" # Don't force JSON for manual tool calling mode + + ollama_params["messages"] = ollama_messages + + # Token counts will be returned + prompt_tokens = 0 + completion_tokens = 0 + total_tokens = 0 + + ans = None + try: + if "client_host" in params: + client = Client(host=params["client_host"]) + response = client.chat(**ollama_params) + else: + response = ollama.chat(**ollama_params) + except Exception as e: + raise RuntimeError(f"Ollama exception occurred: {e}") + else: + + if ollama_params["stream"]: + # Read in the chunks as they stream, taking in tool_calls which may be across + # multiple chunks if more than one suggested + ans = "" + for chunk in response: + ans = ans + (chunk["message"]["content"] or "") + + if "done_reason" in chunk: + prompt_tokens = chunk["prompt_eval_count"] + completion_tokens = chunk["eval_count"] + total_tokens = prompt_tokens + completion_tokens + else: + # Non-streaming finished + ans: str = response["message"]["content"] + + prompt_tokens = response["prompt_eval_count"] + completion_tokens = response["eval_count"] + total_tokens = prompt_tokens + completion_tokens + + if response is not None: + + if ollama_params["stream"]: + response_content = ans + response_id = chunk["created_at"] + else: + response_id = response["created_at"] + + # Are we doing a manual tool call + is_manual_tool_calling = False + + if self._tools_in_conversation and self._tool_calling_mode == "manual": + # Try to convert the response to a tool call object + response_toolcalls = response_to_tool_call(ans) + + # If we can, then it's a manual tool call + if response_toolcalls is not None: + ollama_finish = "tool_calls" + tool_calls = [] + random_id = random.randint(0, 10000) + + for json_function in response_toolcalls: + tool_calls.append( + ChatCompletionMessageToolCall( + id="ollama_func_{}".format(random_id), + function={ + "name": json_function["name"], + "arguments": ( + json.dumps(json_function["arguments"]) if "arguments" in json_function else "{}" + ), + }, + type="function", + ) + ) + + random_id += 1 + + is_manual_tool_calling = True + + # Blank the message content + response_content = "" + + if not is_manual_tool_calling: + if not ollama_params["stream"]: + response_content = response["message"]["content"] + ollama_finish = "stop" + tool_calls = None + else: + raise RuntimeError("Failed to get response from Ollama after retrying 5 times.") + + # 3. convert output + message = ChatCompletionMessage( + role="assistant", + content=response_content, + function_call=None, + tool_calls=tool_calls, + ) + choices = [Choice(finish_reason=ollama_finish, index=0, message=message)] + + response_oai = ChatCompletion( + id=response_id, + model=ollama_params["model"], + created=int(time.time()), + object="chat.completion", + choices=choices, + usage=CompletionUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ), + cost=0, # Local models, FREE! + ) + + return response_oai + + def oai_messages_to_ollama_messages(self, messages: list[Dict[str, Any]], tools: list) -> list[dict[str, Any]]: + """Convert messages from OAI format to Ollama's format. + We correct for any specific role orders and types, and convert tools to messages (as Ollama can't use tool messages) + """ + + #### MS UP TO HERE - CONVERT TOOL MESSAGES TO STANDARD MESSAGES #### + + ollama_messages = copy.deepcopy(messages) + + # Remove the name field + for message in ollama_messages: + if "name" in message: + message.pop("name", None) + + # Having a 'system' message on the end does not work well with Ollama, so we change it to 'user' + # 'system' messages on the end are typical of the summarisation message: summary_method="reflection_with_llm" + if len(ollama_messages) > 1 and ollama_messages[-1]["role"] == "system": + ollama_messages[-1]["role"] = "user" + + # Process messages for tool calling manually + if self._tools_in_conversation and self._tool_calling_mode == "manual": + # 1. We need to append instructions to the starting system message on function calling + # 2. If we have not yet called tools we append "step 1 instruction" to the latest user message + # 3. If we have already called tools we append "step 2 instruction" to the latest user message + + have_tool_calls = False + have_tool_results = False + last_tool_result_index = -1 + + for i, message in enumerate(ollama_messages): + if "tool_calls" in message: + have_tool_calls = True + if "tool_call_id" in message: + have_tool_results = True + last_tool_result_index = i + + tool_result_is_last_msg = have_tool_results and last_tool_result_index == len(ollama_messages) - 1 + + if ollama_messages[0]["role"] == "system": + manual_instruction = self._manual_tool_call_instruction + + # Build a string of the functions available + functions_string = "" + for function in tools: + functions_string += f"""\n{function}\n""" + + # Replace single quotes with double questions - Not sure why this helps the LLM perform + # better, but it seems to. Monitor and remove if not necessary. + functions_string = functions_string.replace("'", '"') + + manual_instruction = manual_instruction.replace("[FUNCTIONS_LIST]", functions_string) + + # Update the system message with the instructions and functions + ollama_messages[0]["content"] = ollama_messages[0]["content"] + manual_instruction.rstrip() + + # If we are still in the function calling or evaluating process, append the steps instruction + if not have_tool_calls or tool_result_is_last_msg: + if ollama_messages[0]["role"] == "system": + # NOTE: we require a system message to exist for the manual steps texts + # Append the manual step instructions + content_to_append = ( + self._manual_tool_call_step1 if not have_tool_results else self._manual_tool_call_step2 + ) + + if content_to_append != "": + # Append the relevant tool call instruction to the latest user message + if ollama_messages[-1]["role"] == "user": + ollama_messages[-1]["content"] = ollama_messages[-1]["content"] + content_to_append + else: + ollama_messages.append({"role": "user", "content": content_to_append}) + + # Convert tool call and tool result messages to normal text messages for Ollama + for i, message in enumerate(ollama_messages): + if "tool_calls" in message: + # Recommended tool calls + content = "Run the following function(s):" + for tool_call in message["tool_calls"]: + content = content + "\n" + str(tool_call) + ollama_messages[i] = {"role": "assistant", "content": content} + if "tool_call_id" in message: + # Executed tool results + message["result"] = message["content"] + del message["content"] + content = "The following function was run: " + str(message) + ollama_messages[i] = {"role": "user", "content": content} + + # As we are changing messages, let's merge if they have two user messages on the end and the last one is tool call step instructions + if ( + len(ollama_messages) >= 2 + and ollama_messages[-2]["role"] == "user" + and ollama_messages[-1]["role"] == "user" + and ( + ollama_messages[-1]["content"] == self._manual_tool_call_step1 + or ollama_messages[-1]["content"] == self._manual_tool_call_step2 + ) + ): + ollama_messages[-2]["content"] = ollama_messages[-2]["content"] + ollama_messages[-1]["content"] + del ollama_messages[-1] + + # Ensure the last message is a user / system message, if not, add a user message + if ollama_messages[-1]["role"] != "user" and ollama_messages[-1]["role"] != "system": + ollama_messages.append({"role": "user", "content": "Please continue."}) + + return ollama_messages + + +def response_to_tool_call(response_string: str) -> Any: + """Attempts to convert the response to an object, aimed to align with function format [{},{}]""" + + # We try and detect the list[dict] format: + pattern = r"\[[\s\S]*?\]" + + # Search for the pattern in the input string + matches = re.findall(pattern, response_string.strip()) + + for match in matches: + + # It has matched, extract it and load it + json_str = match.strip() + try: + data_object = json.loads(json_str) + + data_object = _object_to_tool_call(data_object) + + if data_object is not None: + return data_object + except Exception: + pass + + # Couldn't parse with Regular Expression, try as eval + try: + data_object = eval(response_string.strip()) + + data_object = _object_to_tool_call(data_object) + + if data_object is not None: + return data_object + except Exception: + pass + + # There's no tool call in the response + return None + + +def _object_to_tool_call(data_object: Any) -> List[Dict]: + """Attempts to convert an object to a valid tool call object List[Dict] and returns it, if it can, otherwise None""" + + # If it's a dictionary and not a list then wrap in a list + if isinstance(data_object, dict): + data_object = [data_object] + + # Validate that the data is a list of dictionaries + if isinstance(data_object, list) and all(isinstance(item, dict) for item in data_object): + # Perfect format, a list of dictionaries + + # Check that each dictionary has at least 'name', optionally 'arguments' and no other keys + is_invalid = False + for item in data_object: + if not is_valid_tool_call_item(item): + is_invalid = True + break + + # All passed, name and (optionally) arguments exist for all entries. + if not is_invalid: + return data_object + elif isinstance(data_object, list): + # If it's a list but the items are not dictionaries, check if they are strings that can be converted to dictionaries + data_copy = data_object.copy() + is_invalid = False + for i, item in enumerate(data_copy): + try: + new_item = eval(item) + if isinstance(new_item, dict): + if is_valid_tool_call_item(new_item): + data_object[i] = new_item + else: + is_invalid = True + break + else: + is_invalid = True + break + except Exception: + is_invalid = True + break + + if not is_invalid: + return data_object + + return None + + +def is_valid_tool_call_item(call_item: dict) -> bool: + """Check that a dictionary item has at least 'name', optionally 'arguments' and no other keys to match a tool call JSON""" + if "name" not in call_item and not isinstance(call_item["name"], str): + return False + + if set(call_item.keys()) - {"name", "arguments"}: + return False + + return True diff --git a/autogen/runtime_logging.py b/autogen/runtime_logging.py index 4ad76cf5b7d..8cccd312009 100644 --- a/autogen/runtime_logging.py +++ b/autogen/runtime_logging.py @@ -17,6 +17,7 @@ from autogen.oai.gemini import GeminiClient from autogen.oai.groq import GroqClient from autogen.oai.mistral import MistralAIClient + from autogen.oai.ollama import OllamaClient from autogen.oai.together import TogetherClient logger = logging.getLogger(__name__) @@ -111,7 +112,9 @@ def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig def log_new_client( - client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, GroqClient], + client: Union[ + AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, GroqClient, OllamaClient + ], wrapper: OpenAIWrapper, init_args: Dict[str, Any], ) -> None: diff --git a/setup.py b/setup.py index 9a67c70f49d..1cfdfdceed0 100644 --- a/setup.py +++ b/setup.py @@ -92,6 +92,7 @@ "anthropic": ["anthropic>=0.23.1"], "mistral": ["mistralai>=0.2.0"], "groq": ["groq>=0.9.0"], + "ollama": ["ollama>=0.2.1"], } setuptools.setup( diff --git a/test/oai/test_ollama.py b/test/oai/test_ollama.py new file mode 100644 index 00000000000..edb6ba041a2 --- /dev/null +++ b/test/oai/test_ollama.py @@ -0,0 +1,14 @@ +from unittest.mock import MagicMock, patch + +import pytest + +try: + from autogen.oai.ollama import OllamaClient + + skip = False +except ImportError: + OllamaClient = object + InternalServerError = object + skip = True + +# TODO From 42cfe2156ad2809a56106a9c07010398d5e5355e Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Mon, 1 Jul 2024 20:37:44 +0000 Subject: [PATCH 02/14] Tidy comments --- autogen/oai/ollama.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/autogen/oai/ollama.py b/autogen/oai/ollama.py index 0acdc6b8369..6e284b0852c 100644 --- a/autogen/oai/ollama.py +++ b/autogen/oai/ollama.py @@ -303,8 +303,6 @@ def oai_messages_to_ollama_messages(self, messages: list[Dict[str, Any]], tools: We correct for any specific role orders and types, and convert tools to messages (as Ollama can't use tool messages) """ - #### MS UP TO HERE - CONVERT TOOL MESSAGES TO STANDARD MESSAGES #### - ollama_messages = copy.deepcopy(messages) # Remove the name field From 292a16795e2aa6cf4a7b8c3f04313ac56db2c531 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Tue, 2 Jul 2024 02:21:40 +0000 Subject: [PATCH 03/14] Cater for missing prompt token count --- autogen/oai/ollama.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/autogen/oai/ollama.py b/autogen/oai/ollama.py index 6e284b0852c..94248a9516b 100644 --- a/autogen/oai/ollama.py +++ b/autogen/oai/ollama.py @@ -212,15 +212,15 @@ def create(self, params: Dict) -> ChatCompletion: ans = ans + (chunk["message"]["content"] or "") if "done_reason" in chunk: - prompt_tokens = chunk["prompt_eval_count"] - completion_tokens = chunk["eval_count"] + prompt_tokens = chunk["prompt_eval_count"] if "prompt_eval_count" in chunk else 0 + completion_tokens = chunk["eval_count"] if "eval_count" in chunk else 0 total_tokens = prompt_tokens + completion_tokens else: # Non-streaming finished ans: str = response["message"]["content"] - prompt_tokens = response["prompt_eval_count"] - completion_tokens = response["eval_count"] + prompt_tokens = response["prompt_eval_count"] if "prompt_eval_count" in response else 0 + completion_tokens = response["eval_count"] if "eval_count" in response else 0 total_tokens = prompt_tokens + completion_tokens if response is not None: From 4c6eb1e54f1e28fdc42a629a5755e92a50a33282 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Tue, 2 Jul 2024 09:07:22 +0000 Subject: [PATCH 04/14] Removed use of eval, added json parsing support library --- autogen/oai/ollama.py | 25 ++++++++++++------------- setup.py | 2 +- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/autogen/oai/ollama.py b/autogen/oai/ollama.py index 94248a9516b..7871e9949c7 100644 --- a/autogen/oai/ollama.py +++ b/autogen/oai/ollama.py @@ -26,6 +26,7 @@ from typing import Any, Dict, List, Tuple import ollama +from fix_busted_json import repair_json from ollama import Client from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall from openai.types.chat.chat_completion import ChatCompletionMessage, Choice @@ -415,26 +416,24 @@ def response_to_tool_call(response_string: str) -> Any: # It has matched, extract it and load it json_str = match.strip() + data_object = None + try: + # Attempt to convert it as is data_object = json.loads(json_str) + except Exception: + try: + # If that fails, attempt to repair it + fixed_json = repair_json(json_str) + data_object = json.loads(fixed_json) + except Exception: + pass + if data_object is not None: data_object = _object_to_tool_call(data_object) if data_object is not None: return data_object - except Exception: - pass - - # Couldn't parse with Regular Expression, try as eval - try: - data_object = eval(response_string.strip()) - - data_object = _object_to_tool_call(data_object) - - if data_object is not None: - return data_object - except Exception: - pass # There's no tool call in the response return None diff --git a/setup.py b/setup.py index 1cfdfdceed0..8e65374a329 100644 --- a/setup.py +++ b/setup.py @@ -92,7 +92,7 @@ "anthropic": ["anthropic>=0.23.1"], "mistral": ["mistralai>=0.2.0"], "groq": ["groq>=0.9.0"], - "ollama": ["ollama>=0.2.1"], + "ollama": ["ollama>=0.2.1", "fix_busted_json"], } setuptools.setup( From c75155bedc4ef6f1345622b684db46bd91c22df5 Mon Sep 17 00:00:00 2001 From: Mark Sze <66362098+marklysze@users.noreply.github.com> Date: Wed, 3 Jul 2024 06:10:18 +1000 Subject: [PATCH 05/14] Fix to the use of the JSON fix library, handling of Mixtral escape sequence --- autogen/oai/ollama.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/autogen/oai/ollama.py b/autogen/oai/ollama.py index 7871e9949c7..880c675f462 100644 --- a/autogen/oai/ollama.py +++ b/autogen/oai/ollama.py @@ -424,8 +424,20 @@ def response_to_tool_call(response_string: str) -> Any: except Exception: try: # If that fails, attempt to repair it - fixed_json = repair_json(json_str) + + # Enclose to a JSON object for repairing, which is restored upon fix + fixed_json = repair_json("{'temp':" + json_str + "}") data_object = json.loads(fixed_json) + data_object = data_object["temp"] + except json.JSONDecodeError as e: + if e.msg == "Invalid \\escape": + # Handle Mistral/Mixtral trying to escape underlines with \\ + try: + fixed_json = repair_json("{'temp':" + json_str.replace("\\_", "_") + "}") + data_object = json.loads(fixed_json) + data_object = data_object["temp"] + except Exception: + pass except Exception: pass From bdfc9b154b93ca8f086de0aec5b3de8c50d632a0 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Wed, 3 Jul 2024 09:32:48 +0000 Subject: [PATCH 06/14] Fixed 'name' in JSON bug, catered for single function call JSON without [] --- autogen/oai/ollama.py | 78 +++++++++++++++++++++++++------------------ 1 file changed, 46 insertions(+), 32 deletions(-) diff --git a/autogen/oai/ollama.py b/autogen/oai/ollama.py index 880c675f462..052e6bcf653 100644 --- a/autogen/oai/ollama.py +++ b/autogen/oai/ollama.py @@ -55,7 +55,9 @@ class OllamaClient: "it returning only TEXT and not Python or JSON. " "In terms of your response format, for step 1 return only JSON and NO OTHER text, " "for step 2 return only text and NO JSON/Python/Markdown. " - 'The format for running a function is [{"name": "function_name1", "arguments":{"argument_name": "argument_value"}},{"name": "function_name2", "arguments":{"argument_name": "argument_value"}}]\n' + 'The format for running a function is [{"name": "function_name1", "arguments":{"argument_name": "argument_value"}},{"name": "function_name2", "arguments":{"argument_name": "argument_value"}}] ' + 'Make sure the keys "name" and "arguments" are as described. ' + "If you don't get the format correct, try again. " "The following functions are available to you:[FUNCTIONS_LIST]" ) @@ -407,45 +409,57 @@ def response_to_tool_call(response_string: str) -> Any: """Attempts to convert the response to an object, aimed to align with function format [{},{}]""" # We try and detect the list[dict] format: - pattern = r"\[[\s\S]*?\]" + # Pattern 1 is [{},{}] + # Pattern 2 is {} (without the [], so could be a single function call) + patterns = [r"\[[\s\S]*?\]", r"\{[\s\S]*\}"] - # Search for the pattern in the input string - matches = re.findall(pattern, response_string.strip()) + for i, pattern in enumerate(patterns): + # Search for the pattern in the input string + matches = re.findall(pattern, response_string.strip()) - for match in matches: + for match in matches: - # It has matched, extract it and load it - json_str = match.strip() - data_object = None + # It has matched, extract it and load it + json_str = match.strip() + data_object = None - try: - # Attempt to convert it as is - data_object = json.loads(json_str) - except Exception: try: - # If that fails, attempt to repair it - - # Enclose to a JSON object for repairing, which is restored upon fix - fixed_json = repair_json("{'temp':" + json_str + "}") - data_object = json.loads(fixed_json) - data_object = data_object["temp"] - except json.JSONDecodeError as e: - if e.msg == "Invalid \\escape": - # Handle Mistral/Mixtral trying to escape underlines with \\ - try: - fixed_json = repair_json("{'temp':" + json_str.replace("\\_", "_") + "}") - data_object = json.loads(fixed_json) - data_object = data_object["temp"] - except Exception: - pass + # Attempt to convert it as is + data_object = json.loads(json_str) except Exception: - pass + try: + # If that fails, attempt to repair it - if data_object is not None: - data_object = _object_to_tool_call(data_object) + if i == 0: + # Enclose to a JSON object for repairing, which is restored upon fix + fixed_json = repair_json("{'temp':" + json_str + "}") + data_object = json.loads(fixed_json) + data_object = data_object["temp"] + else: + fixed_json = repair_json(json_str) + data_object = json.loads(fixed_json) + except json.JSONDecodeError as e: + if e.msg == "Invalid \\escape": + # Handle Mistral/Mixtral trying to escape underlines with \\ + try: + json_str = json_str.replace("\\_", "_") + if i == 0: + fixed_json = repair_json("{'temp':" + json_str + "}") + data_object = json.loads(fixed_json) + data_object = data_object["temp"] + else: + fixed_json = repair_json("{'temp':" + json_str + "}") + data_object = json.loads(fixed_json) + except Exception: + pass + except Exception: + pass if data_object is not None: - return data_object + data_object = _object_to_tool_call(data_object) + + if data_object is not None: + return data_object # There's no tool call in the response return None @@ -500,7 +514,7 @@ def _object_to_tool_call(data_object: Any) -> List[Dict]: def is_valid_tool_call_item(call_item: dict) -> bool: """Check that a dictionary item has at least 'name', optionally 'arguments' and no other keys to match a tool call JSON""" - if "name" not in call_item and not isinstance(call_item["name"], str): + if "name" not in call_item or not isinstance(call_item["name"], str): return False if set(call_item.keys()) - {"name", "arguments"}: From ec124c1381f2b5a5af3850d9e9737e332a8d6e26 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Wed, 3 Jul 2024 15:10:45 +0000 Subject: [PATCH 07/14] removing role='tool' from inner tool result to reduce token usage. --- autogen/oai/ollama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/autogen/oai/ollama.py b/autogen/oai/ollama.py index 052e6bcf653..a9155b66835 100644 --- a/autogen/oai/ollama.py +++ b/autogen/oai/ollama.py @@ -382,6 +382,7 @@ def oai_messages_to_ollama_messages(self, messages: list[Dict[str, Any]], tools: # Executed tool results message["result"] = message["content"] del message["content"] + del message["role"] content = "The following function was run: " + str(message) ollama_messages[i] = {"role": "user", "content": content} From 513b59e0ef5164712d3d5de12edb2ac595dff6d7 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Fri, 26 Jul 2024 01:08:36 +0000 Subject: [PATCH 08/14] Added Ollama documentation and updated library versions --- setup.py | 2 +- .../non-openai-models/local-ollama.ipynb | 551 ++++++++++++++++++ 2 files changed, 552 insertions(+), 1 deletion(-) create mode 100644 website/docs/topics/non-openai-models/local-ollama.ipynb diff --git a/setup.py b/setup.py index 71ee72aec78..c26a343fdc5 100644 --- a/setup.py +++ b/setup.py @@ -90,7 +90,7 @@ "mistral": ["mistralai>=0.2.0"], "groq": ["groq>=0.9.0"], "cohere": ["cohere>=5.5.8"], - "ollama": ["ollama>=0.2.1", "fix_busted_json"], + "ollama": ["ollama>=0.3.0", "fix_busted_json>=0.0.18"], } setuptools.setup( diff --git a/website/docs/topics/non-openai-models/local-ollama.ipynb b/website/docs/topics/non-openai-models/local-ollama.ipynb new file mode 100644 index 00000000000..7c1e77df6a6 --- /dev/null +++ b/website/docs/topics/non-openai-models/local-ollama.ipynb @@ -0,0 +1,551 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Ollama\n", + "\n", + "[Ollama](https://ollama.com/) is a local inference engine that enables you to run open-weight LLMs in your environment. It has native support for a large number of models such as Google's Gemma, Meta's Llama 2/3/3.1, Microsoft's Phi 3, Mistral.AI's Mistral/Mixtral, and Cohere's Command R models.\n", + "\n", + "Note: Previously, to use Ollama with AutoGen you required LiteLLM. Now it can be used directly and supports tool calling." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Features\n", + "\n", + "When using this Ollama client class, messages are tailored to accommodate the specific requirements of Ollama's API and this includes message role sequences, support for function/tool calling, and token usage.\n", + "\n", + "## Installing Ollama\n", + "\n", + "For Mac and Windows, [download Ollama](https://ollama.com/download).\n", + "\n", + "For Linux:\n", + "\n", + "```bash\n", + "curl -fsSL https://ollama.com/install.sh | sh\n", + "```\n", + "\n", + "## Downloading models for Ollama\n", + "\n", + "Ollama has a library of models to choose from, see them [here](https://ollama.com/library).\n", + "\n", + "Before you can use a model, you need to download it (using the name of the model from the library):\n", + "\n", + "```bash\n", + "ollama pull llama3.1\n", + "```\n", + "\n", + "To view the models you have downloaded and can use:\n", + "\n", + "```bash\n", + "ollama list\n", + "```\n", + "\n", + "## Getting started with AutoGen and Ollama\n", + "\n", + "When installing AutoGen, you need to install the `pyautogen` package with the Ollama library.\n", + "\n", + "``` bash\n", + "pip install pyautogen[ollama]\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "See the sample `OAI_CONFIG_LIST` below showing how the Ollama client class is used by specifying the `api_type` as `ollama`.\n", + "\n", + "```python\n", + "[\n", + " {\n", + " \"model\": \"llama3.1\",\n", + " \"api_type\": \"ollama\"\n", + " },\n", + " {\n", + " \"model\": \"llama3.1:8b-instruct-q6_K\",\n", + " \"api_type\": \"ollama\"\n", + " },\n", + " {\n", + " \"model\": \"mistral-nemo\",\n", + " \"api_type\": \"ollama\"\n", + " }\n", + "]\n", + "```\n", + "\n", + "If you need to specify the URL for your Ollama install, use the `client_host` key in your config as per the below example:\n", + "\n", + "```python\n", + "[\n", + " {\n", + " \"model\": \"llama3.1\",\n", + " \"api_type\": \"ollama\",\n", + " \"client_host\": \"http://192.168.0.1:11434\"\n", + " }\n", + "]\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## API parameters\n", + "\n", + "The following Ollama parameters can be added to your config. See [this link](https://github.com/ollama/ollama/blob/main/docs/api.md#parameters) for further information on them.\n", + "\n", + "- num_predict (integer): -1 is infinite, -2 is fill context, 128 is default\n", + "- repeat_penalty (float)\n", + "- seed (integer)\n", + "- stream (boolean)\n", + "- temperature (float)\n", + "- top_k (int)\n", + "- top_p (float)\n", + "\n", + "Example:\n", + "```python\n", + "[\n", + " {\n", + " \"model\": \"llama3.1:instruct\",\n", + " \"api_type\": \"ollama\",\n", + " \"num_predict\": -1,\n", + " \"repeat_penalty\": 1.1,\n", + " \"seed\": 42,\n", + " \"stream\": False,\n", + " \"temperature\": 1,\n", + " \"top_k\": 50,\n", + " \"top_p\": 0.8\n", + " }\n", + "]\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Two-Agent Coding Example\n", + "\n", + "In this example, we run a two-agent chat with an AssistantAgent (primarily a coding agent) to generate code to count the number of prime numbers between 1 and 10,000 and then it will be executed.\n", + "\n", + "We'll use Meta's Llama 3.1 model which is suitable for coding.\n", + "\n", + "In this example we will specify the URL for the Ollama installation using `client_host`." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "config_list = [\n", + " {\n", + " # Let's choose the Meta's Llama 3.1 model (model names must match Ollama exactly)\n", + " \"model\": \"llama3.1\",\n", + " # We specify the API Type as 'ollama' so it uses the Ollama client class\n", + " \"api_type\": \"ollama\",\n", + " \"stream\": False,\n", + " \"client_host\": \"http://192.168.0.1:11434\",\n", + " }\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Importantly, we have tweaked the system message so that the model doesn't return the termination keyword, which we've changed to FINISH, with the code block." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from pathlib import Path\n", + "\n", + "from autogen import AssistantAgent, UserProxyAgent\n", + "from autogen.coding import LocalCommandLineCodeExecutor\n", + "\n", + "# Setting up the code executor\n", + "workdir = Path(\"coding\")\n", + "workdir.mkdir(exist_ok=True)\n", + "code_executor = LocalCommandLineCodeExecutor(work_dir=workdir)\n", + "\n", + "# Setting up the agents\n", + "\n", + "# The UserProxyAgent will execute the code that the AssistantAgent provides\n", + "user_proxy_agent = UserProxyAgent(\n", + " name=\"User\",\n", + " code_execution_config={\"executor\": code_executor},\n", + " is_termination_msg=lambda msg: \"FINISH\" in msg.get(\"content\"),\n", + ")\n", + "\n", + "system_message = \"\"\"You are a helpful AI assistant who writes code and the user\n", + "executes it. Solve tasks using your python coding skills.\n", + "In the following cases, suggest python code (in a python coding block) for the\n", + "user to execute. When using code, you must indicate the script type in the code block.\n", + "You only need to create one working sample.\n", + "Do not suggest incomplete code which requires users to modify it.\n", + "Don't use a code block if it's not intended to be executed by the user. Don't\n", + "include multiple code blocks in one response. Do not ask users to copy and\n", + "paste the result. Instead, use 'print' function for the output when relevant.\n", + "Check the execution result returned by the user.\n", + "\n", + "If the result indicates there is an error, fix the error.\n", + "\n", + "IMPORTANT: If it has executed successfully, ONLY output 'FINISH'.\"\"\"\n", + "\n", + "# The AssistantAgent, using the Ollama config, will take the coding request and return code\n", + "assistant_agent = AssistantAgent(\n", + " name=\"Ollama Assistant\",\n", + " system_message=system_message,\n", + " llm_config={\"config_list\": config_list},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now start the chat." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mUser\u001b[0m (to Ollama Assistant):\n", + "\n", + "Provide code to count the number of prime numbers from 1 to 10000.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mOllama Assistant\u001b[0m (to User):\n", + "\n", + "```python\n", + "def is_prime(n):\n", + " \"\"\"Check if a number is prime.\"\"\"\n", + " if n < 2:\n", + " return False\n", + " for i in range(2, int(n**0.5) + 1):\n", + " if n % i == 0:\n", + " return False\n", + " return True\n", + "\n", + "\n", + "def count_primes():\n", + " \"\"\"Count the number of prime numbers from 1 to 10000.\"\"\"\n", + " count = sum(1 for num in range(1, 10001) if is_prime(num))\n", + " print(count)\n", + "\n", + "\n", + "# Execute the function\n", + "count_primes()\n", + "```\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[31m\n", + ">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n", + "\u001b[31m\n", + ">>>>>>>> USING AUTO REPLY...\u001b[0m\n", + "\u001b[31m\n", + ">>>>>>>> EXECUTING CODE BLOCK (inferred language is python)...\u001b[0m\n", + "\u001b[33mUser\u001b[0m (to Ollama Assistant):\n", + "\n", + "exitcode: 0 (execution succeeded)\n", + "Code output: 1229\n", + "\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mOllama Assistant\u001b[0m (to User):\n", + "\n", + "FINISH\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[31m\n", + ">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n" + ] + } + ], + "source": [ + "# Start the chat, with the UserProxyAgent asking the AssistantAgent the message\n", + "chat_result = user_proxy_agent.initiate_chat(\n", + " assistant_agent,\n", + " message=\"Provide code to count the number of prime numbers from 1 to 10000.\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tool Call Example\n", + "\n", + "In this example, instead of writing code, we will have an agent assist with some trip planning using multiple tool calling.\n", + "\n", + "Again, we'll use Meta's versatile Llama 3.1." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "from typing import Literal\n", + "\n", + "from typing_extensions import Annotated\n", + "\n", + "import autogen\n", + "\n", + "config_list = [\n", + " {\n", + " # Let's choose the Meta's Llama 3.1 model (model names must match Ollama exactly)\n", + " \"model\": \"llama3.1\",\n", + " \"api_type\": \"ollama\",\n", + " \"stream\": False,\n", + " \"client_host\": \"http://192.168.0.1:11434\",\n", + " }\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll create our agents" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# Create the agent for tool calling\n", + "chatbot = autogen.AssistantAgent(\n", + " name=\"chatbot\",\n", + " system_message=\"\"\"For currency exchange and weather forecasting tasks,\n", + " only use the functions you have been provided with.\n", + " Output 'HAVE FUN!' when an answer has been provided.\"\"\",\n", + " llm_config={\"config_list\": config_list},\n", + ")\n", + "\n", + "# Note that we have changed the termination string to be \"HAVE FUN!\"\n", + "user_proxy = autogen.UserProxyAgent(\n", + " name=\"user_proxy\",\n", + " is_termination_msg=lambda x: x.get(\"content\", \"\") and \"HAVE FUN!\" in x.get(\"content\", \"\"),\n", + " human_input_mode=\"NEVER\",\n", + " max_consecutive_auto_reply=1,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create and register our functions (tools). See the [tutorial chapter on tool use](/docs/tutorial/tool-use) \n", + "for more information." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Currency Exchange function\n", + "\n", + "CurrencySymbol = Literal[\"USD\", \"EUR\"]\n", + "\n", + "# Define our function that we expect to call\n", + "\n", + "\n", + "def exchange_rate(base_currency: CurrencySymbol, quote_currency: CurrencySymbol) -> float:\n", + " if base_currency == quote_currency:\n", + " return 1.0\n", + " elif base_currency == \"USD\" and quote_currency == \"EUR\":\n", + " return 1 / 1.1\n", + " elif base_currency == \"EUR\" and quote_currency == \"USD\":\n", + " return 1.1\n", + " else:\n", + " raise ValueError(f\"Unknown currencies {base_currency}, {quote_currency}\")\n", + "\n", + "\n", + "# Register the function with the agent\n", + "\n", + "\n", + "@user_proxy.register_for_execution()\n", + "@chatbot.register_for_llm(description=\"Currency exchange calculator.\")\n", + "def currency_calculator(\n", + " base_amount: Annotated[float, \"Amount of currency in base_currency\"],\n", + " base_currency: Annotated[CurrencySymbol, \"Base currency\"] = \"USD\",\n", + " quote_currency: Annotated[CurrencySymbol, \"Quote currency\"] = \"EUR\",\n", + ") -> str:\n", + " quote_amount = exchange_rate(base_currency, quote_currency) * base_amount\n", + " return f\"{format(quote_amount, '.2f')} {quote_currency}\"\n", + "\n", + "\n", + "# Weather function\n", + "\n", + "\n", + "# Example function to make available to model\n", + "def get_current_weather(location, unit=\"fahrenheit\"):\n", + " \"\"\"Get the weather for some location\"\"\"\n", + " if \"chicago\" in location.lower():\n", + " return json.dumps({\"location\": \"Chicago\", \"temperature\": \"13\", \"unit\": unit})\n", + " elif \"san francisco\" in location.lower():\n", + " return json.dumps({\"location\": \"San Francisco\", \"temperature\": \"55\", \"unit\": unit})\n", + " elif \"new york\" in location.lower():\n", + " return json.dumps({\"location\": \"New York\", \"temperature\": \"11\", \"unit\": unit})\n", + " else:\n", + " return json.dumps({\"location\": location, \"temperature\": \"unknown\"})\n", + "\n", + "\n", + "# Register the function with the agent\n", + "\n", + "\n", + "@user_proxy.register_for_execution()\n", + "@chatbot.register_for_llm(description=\"Weather forecast for US cities.\")\n", + "def weather_forecast(\n", + " location: Annotated[str, \"City name\"],\n", + ") -> str:\n", + " weather_details = get_current_weather(location=location)\n", + " weather = json.loads(weather_details)\n", + " return f\"{weather['location']} will be {weather['temperature']} degrees {weather['unit']}\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And run it!" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "What's the weather in New York and can you tell me how much is 123.45 EUR in USD so I can spend it on my holiday? Throw a few holiday tips in as well.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", + "\n", + "\n", + "\u001b[32m***** Suggested tool call (ollama_func_2863): weather_forecast *****\u001b[0m\n", + "Arguments: \n", + "{\"location\": \"New York\"}\n", + "\u001b[32m********************************************************************\u001b[0m\n", + "\u001b[32m***** Suggested tool call (ollama_func_2864): currency_calculator *****\u001b[0m\n", + "Arguments: \n", + "{\"base_amount\": 123.45, \"quote_currency\": \"USD\", \"base_currency\": \"EUR\"}\n", + "\u001b[32m***********************************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[35m\n", + ">>>>>>>> EXECUTING FUNCTION weather_forecast...\u001b[0m\n", + "\u001b[35m\n", + ">>>>>>>> EXECUTING FUNCTION currency_calculator...\u001b[0m\n", + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "\u001b[32m***** Response from calling tool (ollama_func_2863) *****\u001b[0m\n", + "New York will be 11 degrees fahrenheit\n", + "\u001b[32m*********************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "\u001b[32m***** Response from calling tool (ollama_func_2864) *****\u001b[0m\n", + "135.80 USD\n", + "\u001b[32m*********************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", + "\n", + "So it's going to be a chilly winter in New York! \n", + "\n", + "Now, let's talk about your holiday expenses. You've got $135.80 USD to spend in New York, which is great for exploring the city. Here are some tips:\n", + "\n", + "* Make sure to try a classic NYC hot dog from a street vendor - it's a must-try!\n", + "* Take a stroll across the Brooklyn Bridge for stunning views of the Manhattan skyline.\n", + "* Visit the iconic Central Park and take a leisurely walk through the gardens.\n", + "* Don't miss out on trying some delicious pizza slices from one of the many pizzerias in Little Italy.\n", + "\n", + "Have fun in New York!\n", + "\n", + "--------------------------------------------------------------------------------\n", + "LLM SUMMARY: New York will be 11 degrees Fahrenheit. $135.80 USD is equivalent to €123.45 EUR. Consider trying a classic NYC hot dog, walking across the Brooklyn Bridge, visiting Central Park, and trying pizza slices in Little Italy during your holiday.\n" + ] + } + ], + "source": [ + "# start the conversation\n", + "res = user_proxy.initiate_chat(\n", + " chatbot,\n", + " message=\"What's the weather in New York and can you tell me how much is 123.45 EUR in USD so I can spend it on my holiday? Throw a few holiday tips in as well.\",\n", + " summary_method=\"reflection_with_llm\",\n", + ")\n", + "\n", + "print(f\"LLM SUMMARY: {res.summary['content']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Great, we can see that Llama 3.1 has helped choose the right functions, their parameters, and then summarised them for us." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 0821e83b12e04b298385443fda031ec91521ce29 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Sat, 27 Jul 2024 01:10:22 +0000 Subject: [PATCH 09/14] Added Native Ollama tool calling (v0.3.0 req.) as well as hide/show tools support --- autogen/oai/ollama.py | 154 ++++++++++++++++++++++++++++-------------- 1 file changed, 104 insertions(+), 50 deletions(-) diff --git a/autogen/oai/ollama.py b/autogen/oai/ollama.py index a9155b66835..ad5bf3c8a7e 100644 --- a/autogen/oai/ollama.py +++ b/autogen/oai/ollama.py @@ -23,6 +23,7 @@ import random import re import time +import warnings from typing import Any, Dict, List, Tuple import ollama @@ -146,6 +147,23 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: if "top_p" in params: ollama_params["top_p"] = validate_parameter(params, "top_p", (int, float), False, 0.9, None, None) + if self._native_tool_calls and self._tools_in_conversation and not self._should_hide_tools: + ollama_params["tools"] = params["tools"] + + # Ollama doesn't support streaming with tools natively + if ollama_params["stream"] and self._native_tool_calls: + warnings.warn( + "Streaming is not supported when using tools and 'Native' tool calling, streaming will be disabled.", + UserWarning, + ) + + ollama_params["stream"] = False + + if not self._native_tool_calls and self._tools_in_conversation: + # For manual tool calling we have injected the available tools into the prompt + # and we don't want to force JSON mode + ollama_params["format"] = "" # Don't force JSON for manual tool calling mode + if len(options_dict) != 0: ollama_params["options"] = options_dict @@ -158,13 +176,22 @@ def create(self, params: Dict) -> ChatCompletion: # Are tools involved in this conversation? self._tools_in_conversation = "tools" in params - # Function/Tool calling options - # For the time-being Ollama does not support tool calling, so we will handle this - # manually by providing guidance to the LLM and parsing responses to look for tool calls - # This variable could be omitted but I think it is useful to keep in for now. - self._tool_calling_mode = "manual" + # We provide second-level filtering out of tools to avoid LLMs re-calling tools continuously + if self._tools_in_conversation: + hide_tools = validate_parameter( + params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"] + ) + self._should_hide_tools = should_hide_tools(messages, params["tools"], hide_tools) + else: + self._should_hide_tools = False - if self._tool_calling_mode == "manual": + # Are we using native Ollama tool calling, otherwise we're doing manual tool calling + # We allow the user to decide if they want to use Ollama's tool calling + # or for tool calling to be handled manually through text messages + # Default is True = Ollama's tool calling + self._native_tool_calls = validate_parameter(params, "native_tool_calls", bool, False, True, None, None) + + if not self._native_tool_calls: # Load defaults self._manual_tool_call_instruction = validate_parameter( params, "manual_tool_call_instruction", str, False, self.TOOL_CALL_MANUAL_INSTRUCTION, None, None @@ -178,17 +205,17 @@ def create(self, params: Dict) -> ChatCompletion: # Convert AutoGen messages to Ollama messages ollama_messages = self.oai_messages_to_ollama_messages( - messages, params["tools"] if self._tools_in_conversation else None + messages, + ( + params["tools"] + if (not self._native_tool_calls and self._tools_in_conversation) and not self._should_hide_tools + else None + ), ) # Parse parameters to the Ollama API's parameters ollama_params = self.parse_params(params) - # Add tools to the call if we have them and aren't hiding them - if self._tools_in_conversation: - # For Ollama we will inject the available tools into the prompt - ollama_params["format"] = "" # Don't force JSON for manual tool calling mode - ollama_params["messages"] = ollama_messages # Token counts will be returned @@ -228,55 +255,81 @@ def create(self, params: Dict) -> ChatCompletion: if response is not None: + # Defaults + ollama_finish = "stop" + tool_calls = None + + # Id and streaming text into response if ollama_params["stream"]: response_content = ans response_id = chunk["created_at"] else: + response_content = response["message"]["content"] response_id = response["created_at"] - # Are we doing a manual tool call - is_manual_tool_calling = False - - if self._tools_in_conversation and self._tool_calling_mode == "manual": - # Try to convert the response to a tool call object - response_toolcalls = response_to_tool_call(ans) - - # If we can, then it's a manual tool call - if response_toolcalls is not None: - ollama_finish = "tool_calls" - tool_calls = [] - random_id = random.randint(0, 10000) - - for json_function in response_toolcalls: - tool_calls.append( - ChatCompletionMessageToolCall( - id="ollama_func_{}".format(random_id), - function={ - "name": json_function["name"], - "arguments": ( - json.dumps(json_function["arguments"]) if "arguments" in json_function else "{}" - ), - }, - type="function", + # Process tools in the response + if self._tools_in_conversation: + + if self._native_tool_calls: + + if not ollama_params["stream"]: + response_content = response["message"]["content"] + + # Native tool calling + if "tool_calls" in response["message"]: + ollama_finish = "tool_calls" + tool_calls = [] + random_id = random.randint(0, 10000) + for tool_call in response["message"]["tool_calls"]: + tool_calls.append( + ChatCompletionMessageToolCall( + id="ollama_func_{}".format(random_id), + function={ + "name": tool_call["function"]["name"], + "arguments": json.dumps(tool_call["function"]["arguments"]), + }, + type="function", + ) + ) + + random_id += 1 + + elif not self._native_tool_calls: + + # Try to convert the response to a tool call object + response_toolcalls = response_to_tool_call(ans) + + # If we can, then we've got tool call(s) + if response_toolcalls is not None: + ollama_finish = "tool_calls" + tool_calls = [] + random_id = random.randint(0, 10000) + + for json_function in response_toolcalls: + tool_calls.append( + ChatCompletionMessageToolCall( + id="ollama_manual_func_{}".format(random_id), + function={ + "name": json_function["name"], + "arguments": ( + json.dumps(json_function["arguments"]) + if "arguments" in json_function + else "{}" + ), + }, + type="function", + ) ) - ) - - random_id += 1 - is_manual_tool_calling = True + random_id += 1 - # Blank the message content - response_content = "" + # Blank the message content + response_content = "" - if not is_manual_tool_calling: - if not ollama_params["stream"]: - response_content = response["message"]["content"] - ollama_finish = "stop" - tool_calls = None else: - raise RuntimeError("Failed to get response from Ollama after retrying 5 times.") + raise RuntimeError("Failed to get response from Ollama.") - # 3. convert output + # Convert response to AutoGen response message = ChatCompletionMessage( role="assistant", content=response_content, @@ -319,7 +372,7 @@ def oai_messages_to_ollama_messages(self, messages: list[Dict[str, Any]], tools: ollama_messages[-1]["role"] = "user" # Process messages for tool calling manually - if self._tools_in_conversation and self._tool_calling_mode == "manual": + if tools is not None and not self._native_tool_calls: # 1. We need to append instructions to the starting system message on function calling # 2. If we have not yet called tools we append "step 1 instruction" to the latest user message # 3. If we have already called tools we append "step 2 instruction" to the latest user message @@ -389,6 +442,7 @@ def oai_messages_to_ollama_messages(self, messages: list[Dict[str, Any]], tools: # As we are changing messages, let's merge if they have two user messages on the end and the last one is tool call step instructions if ( len(ollama_messages) >= 2 + and not self._native_tool_calls and ollama_messages[-2]["role"] == "user" and ollama_messages[-1]["role"] == "user" and ( From c0f5f6db91c9f0f5874ad24de90b079e2e80329d Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Sat, 27 Jul 2024 02:05:28 +0000 Subject: [PATCH 10/14] Added native tool calling and hide_tools parameter to documentation --- .../non-openai-models/local-ollama.ipynb | 144 ++++++++++++++---- 1 file changed, 112 insertions(+), 32 deletions(-) diff --git a/website/docs/topics/non-openai-models/local-ollama.ipynb b/website/docs/topics/non-openai-models/local-ollama.ipynb index 7c1e77df6a6..693f3f1a734 100644 --- a/website/docs/topics/non-openai-models/local-ollama.ipynb +++ b/website/docs/topics/non-openai-models/local-ollama.ipynb @@ -146,7 +146,7 @@ "config_list = [\n", " {\n", " # Let's choose the Meta's Llama 3.1 model (model names must match Ollama exactly)\n", - " \"model\": \"llama3.1\",\n", + " \"model\": \"llama3.1:8b\",\n", " # We specify the API Type as 'ollama' so it uses the Ollama client class\n", " \"api_type\": \"ollama\",\n", " \"stream\": False,\n", @@ -244,25 +244,19 @@ "\n", "```python\n", "def is_prime(n):\n", - " \"\"\"Check if a number is prime.\"\"\"\n", - " if n < 2:\n", + " if n <= 1:\n", " return False\n", " for i in range(2, int(n**0.5) + 1):\n", " if n % i == 0:\n", " return False\n", " return True\n", "\n", - "\n", - "def count_primes():\n", - " \"\"\"Count the number of prime numbers from 1 to 10000.\"\"\"\n", - " count = sum(1 for num in range(1, 10001) if is_prime(num))\n", - " print(count)\n", - "\n", - "\n", - "# Execute the function\n", - "count_primes()\n", + "count = sum(is_prime(i) for i in range(1, 10001))\n", + "print(count)\n", "```\n", "\n", + "Please execute this code. I will wait for the result.\n", + "\n", "--------------------------------------------------------------------------------\n", "\u001b[31m\n", ">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n", @@ -295,6 +289,69 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tool Calling - Native vs Manual\n", + "\n", + "Ollama supports native tool calling (Ollama v0.3.0 library onward). If you install AutoGen with `pip install pyautogen[ollama]` you will be able to use native tool calling.\n", + "\n", + "The parameter `native_tool_calls` in your configuration allows you to specify if you want to use Ollama's native tool calling (default) or manual tool calling.\n", + "\n", + "```python\n", + "[\n", + " {\n", + " \"model\": \"llama3.1\",\n", + " \"api_type\": \"ollama\",\n", + " \"client_host\": \"http://192.168.0.1:11434\",\n", + " \"native_tool_calls\": True # Use Ollama's native tool calling, False for manual\n", + " }\n", + "]\n", + "```\n", + "\n", + "Native tool calling only works with certain models and an exception will be thrown if you try to use it with an unsupported model.\n", + "\n", + "Manual tool calling allows you to use tool calling with any Ollama model. It incorporates guided tool calling messages into the prompt that guide the LLM through the process of selecting a tool and then evaluating the result of the tool. As to be expected, the ability to follow instructions and return formatted JSON is highly dependent on the model.\n", + "\n", + "You can tailor the manual tool calling messages by adding these parameters to your configuration:\n", + "\n", + "- `manual_tool_call_instruction`\n", + "- `manual_tool_call_step1`\n", + "- `manual_tool_call_step2`\n", + "\n", + "To use manual tool calling set `native_tool_calls` to `False`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Reducing repetitive tool calls\n", + "\n", + "By incorporating tools into a conversation, LLMs can often continually recommend them to be called, even after they've been called and a result returned. This can lead to a never ending cycle of tool calls.\n", + "\n", + "To remove the chance of an LLM recommending a tool call, an additional parameter called `hide_tools` can be used to specify when tools are hidden from the LLM. The string values for the parameter are:\n", + "\n", + "- 'never': tools are never hidden\n", + "- 'if_all_run': tools are hidden if all tools have been called\n", + "- 'if_any_run': tools are hidden if any tool has been called\n", + "\n", + "This can be used with native or manual tool calling, an example of a configuration is shown below.\n", + "\n", + "```python\n", + "[\n", + " {\n", + " \"model\": \"llama3.1\",\n", + " \"api_type\": \"ollama\",\n", + " \"client_host\": \"http://192.168.0.1:11434\",\n", + " \"native_tool_calls\": True,\n", + " \"hide_tools\": \"if_any_run\" # Hide tools once any tool has been called\n", + " }\n", + "]\n", + "```" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -303,7 +360,9 @@ "\n", "In this example, instead of writing code, we will have an agent assist with some trip planning using multiple tool calling.\n", "\n", - "Again, we'll use Meta's versatile Llama 3.1." + "Again, we'll use Meta's versatile Llama 3.1.\n", + "\n", + "Native Ollama tool calling will be used and we'll utilise the `hide_tools` parameter to hide the tools once all have been called." ] }, { @@ -322,10 +381,11 @@ "config_list = [\n", " {\n", " # Let's choose the Meta's Llama 3.1 model (model names must match Ollama exactly)\n", - " \"model\": \"llama3.1\",\n", + " \"model\": \"llama3.1:8b\",\n", " \"api_type\": \"ollama\",\n", " \"stream\": False,\n", " \"client_host\": \"http://192.168.0.1:11434\",\n", + " \"hide_tools\": \"if_any_run\",\n", " }\n", "]" ] @@ -334,12 +394,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We'll create our agents" + "We'll create our agents. Importantly, we're using native Ollama tool calling and to help guide it we add the JSON to the system_message so that the number fields aren't wrapped in quotes (becoming strings)." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -348,6 +408,19 @@ " name=\"chatbot\",\n", " system_message=\"\"\"For currency exchange and weather forecasting tasks,\n", " only use the functions you have been provided with.\n", + " Example of the return JSON is:\n", + " {\n", + " \"parameter_1_name\": 100.00,\n", + " \"parameter_2_name\": \"ABC\",\n", + " \"parameter_3_name\": \"DEF\",\n", + " }.\n", + " Another example of the return JSON is:\n", + " {\n", + " \"parameter_1_name\": \"GHI\",\n", + " \"parameter_2_name\": \"ABC\",\n", + " \"parameter_3_name\": \"DEF\",\n", + " \"parameter_4_name\": 123.00,\n", + " }.\n", " Output 'HAVE FUN!' when an answer has been provided.\"\"\",\n", " llm_config={\"config_list\": config_list},\n", ")\n", @@ -371,7 +444,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -399,7 +472,10 @@ "@user_proxy.register_for_execution()\n", "@chatbot.register_for_llm(description=\"Currency exchange calculator.\")\n", "def currency_calculator(\n", - " base_amount: Annotated[float, \"Amount of currency in base_currency\"],\n", + " base_amount: Annotated[\n", + " float,\n", + " \"Amount of currency in base_currency. Type is float, not string, return value should be a number only, e.g. 987.65.\",\n", + " ],\n", " base_currency: Annotated[CurrencySymbol, \"Base currency\"] = \"USD\",\n", " quote_currency: Annotated[CurrencySymbol, \"Quote currency\"] = \"EUR\",\n", ") -> str:\n", @@ -445,7 +521,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -460,13 +536,13 @@ "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", "\n", "\n", - "\u001b[32m***** Suggested tool call (ollama_func_2863): weather_forecast *****\u001b[0m\n", + "\u001b[32m***** Suggested tool call (ollama_func_4506): weather_forecast *****\u001b[0m\n", "Arguments: \n", "{\"location\": \"New York\"}\n", "\u001b[32m********************************************************************\u001b[0m\n", - "\u001b[32m***** Suggested tool call (ollama_func_2864): currency_calculator *****\u001b[0m\n", + "\u001b[32m***** Suggested tool call (ollama_func_4507): currency_calculator *****\u001b[0m\n", "Arguments: \n", - "{\"base_amount\": 123.45, \"quote_currency\": \"USD\", \"base_currency\": \"EUR\"}\n", + "{\"base_amount\": 123.45, \"base_currency\": \"EUR\", \"quote_currency\": \"USD\"}\n", "\u001b[32m***********************************************************************\u001b[0m\n", "\n", "--------------------------------------------------------------------------------\n", @@ -478,33 +554,37 @@ "\n", "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", "\n", - "\u001b[32m***** Response from calling tool (ollama_func_2863) *****\u001b[0m\n", + "\u001b[32m***** Response from calling tool (ollama_func_4506) *****\u001b[0m\n", "New York will be 11 degrees fahrenheit\n", "\u001b[32m*********************************************************\u001b[0m\n", "\n", "--------------------------------------------------------------------------------\n", "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", "\n", - "\u001b[32m***** Response from calling tool (ollama_func_2864) *****\u001b[0m\n", + "\u001b[32m***** Response from calling tool (ollama_func_4507) *****\u001b[0m\n", "135.80 USD\n", "\u001b[32m*********************************************************\u001b[0m\n", "\n", "--------------------------------------------------------------------------------\n", "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", "\n", - "So it's going to be a chilly winter in New York! \n", + "Based on the results, it seems that:\n", + "\n", + "* The weather forecast for New York is expected to be around 11 degrees Fahrenheit.\n", + "* The exchange rate for EUR to USD is currently 1 EUR = 1.3580 USD, so 123.45 EUR is equivalent to approximately 135.80 USD.\n", "\n", - "Now, let's talk about your holiday expenses. You've got $135.80 USD to spend in New York, which is great for exploring the city. Here are some tips:\n", + "As a bonus, here are some holiday tips in New York:\n", "\n", - "* Make sure to try a classic NYC hot dog from a street vendor - it's a must-try!\n", - "* Take a stroll across the Brooklyn Bridge for stunning views of the Manhattan skyline.\n", - "* Visit the iconic Central Park and take a leisurely walk through the gardens.\n", - "* Don't miss out on trying some delicious pizza slices from one of the many pizzerias in Little Italy.\n", + "* Be sure to try a classic New York-style hot dog from a street cart or a diner.\n", + "* Explore the iconic Central Park and take a stroll through the High Line for some great views of the city.\n", + "* Catch a Broadway show or a concert at one of the many world-class venues in the city.\n", "\n", - "Have fun in New York!\n", + "And... HAVE FUN!\n", "\n", "--------------------------------------------------------------------------------\n", - "LLM SUMMARY: New York will be 11 degrees Fahrenheit. $135.80 USD is equivalent to €123.45 EUR. Consider trying a classic NYC hot dog, walking across the Brooklyn Bridge, visiting Central Park, and trying pizza slices in Little Italy during your holiday.\n" + "LLM SUMMARY: The weather forecast for New York is expected to be around 11 degrees Fahrenheit.\n", + "123.45 EUR is equivalent to approximately 135.80 USD.\n", + "Try a classic New York-style hot dog, explore Central Park and the High Line, and catch a Broadway show or concert during your visit.\n" ] } ], From f118d0ec7a0595fdfe3ecd7b4702458797c2fc05 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Tue, 30 Jul 2024 04:45:06 +0000 Subject: [PATCH 11/14] Update to Ollama 0.3.1, added tests --- autogen/oai/ollama.py | 2 +- setup.py | 2 +- test/oai/test_ollama.py | 292 +++++++++++++++++- .../non-openai-models/local-ollama.ipynb | 2 +- 4 files changed, 293 insertions(+), 5 deletions(-) diff --git a/autogen/oai/ollama.py b/autogen/oai/ollama.py index ad5bf3c8a7e..81075b90565 100644 --- a/autogen/oai/ollama.py +++ b/autogen/oai/ollama.py @@ -116,7 +116,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: ollama_params["model"] = params.get("model", None) assert ollama_params[ "model" - ], "Please specify the 'model' in your config list entry to nominate the Ollama model to use. The model must start with 'ollama/' or 'ollama_chat/'." + ], "Please specify the 'model' in your config list entry to nominate the Ollama model to use." ollama_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None) diff --git a/setup.py b/setup.py index 3ca1384c85a..7655e3d352c 100644 --- a/setup.py +++ b/setup.py @@ -91,7 +91,7 @@ "mistral": ["mistralai>=0.2.0"], "groq": ["groq>=0.9.0"], "cohere": ["cohere>=5.5.8"], - "ollama": ["ollama>=0.3.0", "fix_busted_json>=0.0.18"], + "ollama": ["ollama>=0.3.1", "fix_busted_json>=0.0.18"], } setuptools.setup( diff --git a/test/oai/test_ollama.py b/test/oai/test_ollama.py index edb6ba041a2..729e1b95d81 100644 --- a/test/oai/test_ollama.py +++ b/test/oai/test_ollama.py @@ -3,7 +3,7 @@ import pytest try: - from autogen.oai.ollama import OllamaClient + from autogen.oai.ollama import OllamaClient, response_to_tool_call skip = False except ImportError: @@ -11,4 +11,292 @@ InternalServerError = object skip = True -# TODO + +# Fixtures for mock data +@pytest.fixture +def mock_response(): + class MockResponse: + def __init__(self, text, choices, usage, cost, model): + self.text = text + self.choices = choices + self.usage = usage + self.cost = cost + self.model = model + + return MockResponse + + +@pytest.fixture +def ollama_client(): + + # Set Ollama client with some default values + client = OllamaClient() + + client._native_tool_calls = True + client._tools_in_conversation = False + + return client + + +skip_reason = "Ollama dependency is not installed" + + +# Test initialization and configuration +@pytest.mark.skipif(skip, reason=skip_reason) +def test_initialization(): + + # Creation works without an api_key + OllamaClient() + + +# Test parameters +@pytest.mark.skipif(skip, reason=skip_reason) +def test_parsing_params(ollama_client): + # All parameters (with default values) + params = { + "model": "llama3.1:8b", + "temperature": 0.8, + "num_predict": 128, + "repeat_penalty": 1.1, + "seed": 42, + "top_k": 40, + "top_p": 0.9, + "stream": False, + } + expected_params = { + "model": "llama3.1:8b", + "temperature": 0.8, + "num_predict": 128, + "top_k": 40, + "top_p": 0.9, + "options": { + "repeat_penalty": 1.1, + "seed": 42, + }, + "stream": False, + } + result = ollama_client.parse_params(params) + assert result == expected_params + + # Incorrect types, defaults should be set, will show warnings but not trigger assertions + params = { + "model": "llama3.1:8b", + "temperature": "0.5", + "num_predict": "128", + "repeat_penalty": "1.1", + "seed": "42", + "top_k": "40", + "top_p": "0.9", + "stream": "True", + } + result = ollama_client.parse_params(params) + assert result == expected_params + + # Only model, others set as defaults if they are mandatory + params = { + "model": "llama3.1:8b", + } + expected_params = {"model": "llama3.1:8b", "stream": False} + result = ollama_client.parse_params(params) + assert result == expected_params + + # No model + params = { + "temperature": 0.8, + } + + with pytest.raises(AssertionError) as assertinfo: + result = ollama_client.parse_params(params) + + assert "Please specify the 'model' in your config list entry to nominate the Ollama model to use." in str( + assertinfo.value + ) + + +# Test text generation +@pytest.mark.skipif(skip, reason=skip_reason) +@patch("autogen.oai.ollama.OllamaClient.create") +def test_create_response(mock_chat, ollama_client): + # Mock OllamaClient.chat response + mock_ollama_response = MagicMock() + mock_ollama_response.choices = [ + MagicMock(finish_reason="stop", message=MagicMock(content="Example Ollama response", tool_calls=None)) + ] + mock_ollama_response.id = "mock_ollama_response_id" + mock_ollama_response.model = "llama3.1:8b" + mock_ollama_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20) # Example token usage + + mock_chat.return_value = mock_ollama_response + + # Test parameters + params = { + "messages": [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "World"}], + "model": "llama3.1:8b", + } + + # Call the create method + response = ollama_client.create(params) + + # Assertions to check if response is structured as expected + assert ( + response.choices[0].message.content == "Example Ollama response" + ), "Response content should match expected output" + assert response.id == "mock_ollama_response_id", "Response ID should match the mocked response ID" + assert response.model == "llama3.1:8b", "Response model should match the mocked response model" + assert response.usage.prompt_tokens == 10, "Response prompt tokens should match the mocked response usage" + assert response.usage.completion_tokens == 20, "Response completion tokens should match the mocked response usage" + + +# Test functions/tools +@pytest.mark.skipif(skip, reason=skip_reason) +@patch("autogen.oai.ollama.OllamaClient.create") +def test_create_response_with_tool_call(mock_chat, ollama_client): + # Mock OllamaClient.chat response + mock_function = MagicMock(name="currency_calculator") + mock_function.name = "currency_calculator" + mock_function.arguments = '{"base_currency": "EUR", "quote_currency": "USD", "base_amount": 123.45}' + + mock_function_2 = MagicMock(name="get_weather") + mock_function_2.name = "get_weather" + mock_function_2.arguments = '{"location": "New York"}' + + mock_chat.return_value = MagicMock( + choices=[ + MagicMock( + finish_reason="tool_calls", + message=MagicMock( + content="Sample text about the functions", + tool_calls=[ + MagicMock(id="gdRdrvnHh", function=mock_function), + MagicMock(id="abRdrvnHh", function=mock_function_2), + ], + ), + ) + ], + id="mock_ollama_response_id", + model="llama3.1:8b", + usage=MagicMock(prompt_tokens=10, completion_tokens=20), + ) + + # Construct parameters + converted_functions = [ + { + "type": "function", + "function": { + "description": "Currency exchange calculator.", + "name": "currency_calculator", + "parameters": { + "type": "object", + "properties": { + "base_amount": {"type": "number", "description": "Amount of currency in base_currency"}, + }, + "required": ["base_amount"], + }, + }, + } + ] + ollama_messages = [ + {"role": "user", "content": "How much is 123.45 EUR in USD?"}, + {"role": "assistant", "content": "World"}, + ] + + # Call the create method + response = ollama_client.create({"messages": ollama_messages, "tools": converted_functions, "model": "llama3.1:8b"}) + + # Assertions to check if the functions and content are included in the response + assert response.choices[0].message.content == "Sample text about the functions" + assert response.choices[0].message.tool_calls[0].function.name == "currency_calculator" + assert response.choices[0].message.tool_calls[1].function.name == "get_weather" + + +# Test function parsing with manual tool calling +@pytest.mark.skipif(skip, reason=skip_reason) +def test_manual_tool_calling_parsing(ollama_client): + # Test the parsing of a tool call within the response content (fully correct) + response_content = """[{"name": "weather_forecast", "arguments":{"location": "New York"}},{"name": "currency_calculator", "arguments":{"base_amount": 123.45, "quote_currency": "EUR", "base_currency": "USD"}}]""" + + response_tool_calls = response_to_tool_call(response_content) + + expected_tool_calls = [ + {"name": "weather_forecast", "arguments": {"location": "New York"}}, + { + "name": "currency_calculator", + "arguments": {"base_amount": 123.45, "quote_currency": "EUR", "base_currency": "USD"}, + }, + ] + + assert ( + response_tool_calls == expected_tool_calls + ), "Manual Tool Calling Parsing of response did not yield correct tool_calls (full string match)" + + # Test the parsing with a substring containing the response content (should still pass) + response_content = """I will call two functions, weather_forecast and currency_calculator:\n[{"name": "weather_forecast", "arguments":{"location": "New York"}},{"name": "currency_calculator", "arguments":{"base_amount": 123.45, "quote_currency": "EUR", "base_currency": "USD"}}]""" + + response_tool_calls = response_to_tool_call(response_content) + + assert ( + response_tool_calls == expected_tool_calls + ), "Manual Tool Calling Parsing of response did not yield correct tool_calls (partial string match)" + + # Test the parsing with an invalid function call + response_content = """[{"function": "weather_forecast", "args":{"location": "New York"}},{"function": "currency_calculator", "args":{"base_amount": 123.45, "quote_currency": "EUR", "base_currency": "USD"}}]""" + + response_tool_calls = response_to_tool_call(response_content) + + assert ( + response_tool_calls is None + ), "Manual Tool Calling Parsing of response did not yield correct tool_calls (invalid function call)" + + # Test the parsing with plain text + response_content = """Call the weather_forecast function and pass in 'New York' as the 'location' argument.""" + + response_tool_calls = response_to_tool_call(response_content) + + assert ( + response_tool_calls is None + ), "Manual Tool Calling Parsing of response did not yield correct tool_calls (no function in text)" + + +# Test message conversion from OpenAI to Ollama format +@pytest.mark.skipif(skip, reason=skip_reason) +def test_oai_messages_to_ollama_messages(ollama_client): + # Test that the "name" key is removed + test_messages = [ + {"role": "system", "content": "You are a helpful AI bot."}, + {"role": "user", "name": "anne", "content": "Why is the sky blue?"}, + ] + messages = ollama_client.oai_messages_to_ollama_messages(test_messages, None) + + expected_messages = [ + {"role": "system", "content": "You are a helpful AI bot."}, + {"role": "user", "content": "Why is the sky blue?"}, + ] + + assert messages == expected_messages, "'name' was not removed from messages" + + # Test that there isn't a final system message and it's changed to user + test_messages.append({"role": "system", "content": "Summarise the conversation."}) + + messages = ollama_client.oai_messages_to_ollama_messages(test_messages, None) + + expected_messages = [ + {"role": "system", "content": "You are a helpful AI bot."}, + {"role": "user", "content": "Why is the sky blue?"}, + {"role": "user", "content": "Summarise the conversation."}, + ] + + assert messages == expected_messages, "Final 'system' message was not changed to 'user'" + + # Test that the last message is a user or system message and if not, add a continue message + test_messages[2] = {"role": "assistant", "content": "The sky is blue because that's a great colour."} + + messages = ollama_client.oai_messages_to_ollama_messages(test_messages, None) + + expected_messages = [ + {"role": "system", "content": "You are a helpful AI bot."}, + {"role": "user", "content": "Why is the sky blue?"}, + {"role": "assistant", "content": "The sky is blue because that's a great colour."}, + {"role": "user", "content": "Please continue."}, + ] + + assert messages == expected_messages, "'Please continue' message was not appended." diff --git a/website/docs/topics/non-openai-models/local-ollama.ipynb b/website/docs/topics/non-openai-models/local-ollama.ipynb index 693f3f1a734..95803e50e59 100644 --- a/website/docs/topics/non-openai-models/local-ollama.ipynb +++ b/website/docs/topics/non-openai-models/local-ollama.ipynb @@ -295,7 +295,7 @@ "source": [ "## Tool Calling - Native vs Manual\n", "\n", - "Ollama supports native tool calling (Ollama v0.3.0 library onward). If you install AutoGen with `pip install pyautogen[ollama]` you will be able to use native tool calling.\n", + "Ollama supports native tool calling (Ollama v0.3.1 library onward). If you install AutoGen with `pip install pyautogen[ollama]` you will be able to use native tool calling.\n", "\n", "The parameter `native_tool_calls` in your configuration allows you to specify if you want to use Ollama's native tool calling (default) or manual tool calling.\n", "\n", From a43dac41b65c11dda9bef01f0ad0522066354738 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Tue, 30 Jul 2024 23:54:26 +0000 Subject: [PATCH 12/14] Tweak to manual function calling prompt to improve number handling. --- autogen/oai/ollama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/autogen/oai/ollama.py b/autogen/oai/ollama.py index 81075b90565..cf154f635da 100644 --- a/autogen/oai/ollama.py +++ b/autogen/oai/ollama.py @@ -54,6 +54,7 @@ class OllamaClient: "or more functions based on the request given and return only JSON with the functions and " "arguments to use. The second step is to analyse the given output of the function and summarise " "it returning only TEXT and not Python or JSON. " + "For argument values, be sure numbers aren't strings, they should not have double quotes around them. " "In terms of your response format, for step 1 return only JSON and NO OTHER text, " "for step 2 return only text and NO JSON/Python/Markdown. " 'The format for running a function is [{"name": "function_name1", "arguments":{"argument_name": "argument_value"}},{"name": "function_name2", "arguments":{"argument_name": "argument_value"}}] ' From a90cfd51ad6cacbc2d8511671cf6dbc4e1bb1628 Mon Sep 17 00:00:00 2001 From: Mark Sze <66362098+marklysze@users.noreply.github.com> Date: Mon, 2 Sep 2024 05:50:25 +1000 Subject: [PATCH 13/14] Update client.py fix indent --- autogen/oai/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 5e4612686fa..8d4ff2e1b9b 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -96,7 +96,7 @@ ollama_import_exception: Optional[ImportError] = None except ImportError as e: ollama_import_exception = e - + try: from autogen.oai.bedrock import BedrockClient From d4d665d5183dcea43da6d5a1ce0401f5e4b7327b Mon Sep 17 00:00:00 2001 From: Mark Sze <66362098+marklysze@users.noreply.github.com> Date: Mon, 2 Sep 2024 18:06:05 +1000 Subject: [PATCH 14/14] Update setup.py - Ollama package version correction --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 48dc0a11cca..911c8ef7f6b 100644 --- a/setup.py +++ b/setup.py @@ -97,7 +97,7 @@ "mistral": ["mistralai>=1.0.1"], "groq": ["groq>=0.9.0"], "cohere": ["cohere>=5.5.8"], - "ollama": ["ollama>=0.3.4", "fix_busted_json>=0.0.18"], + "ollama": ["ollama>=0.3.2", "fix_busted_json>=0.0.18"], "bedrock": ["boto3>=1.34.149"], }