From f9716344015f556c50bdc62bc827b382ae819c90 Mon Sep 17 00:00:00 2001 From: smathot Date: Tue, 16 Apr 2024 13:49:42 +0200 Subject: [PATCH] Cheap and expensive unittests work now --- heymans/heymans.py | 14 ++++-- heymans/messages.py | 6 +-- heymans/model/__init__.py | 17 ++++--- heymans/model/_anthropic_model.py | 38 +++++++++------- heymans/model/_base_model.py | 11 +++-- heymans/model/_dummy_model.py | 17 +++++++ heymans/model/_mistral_model.py | 12 ++--- heymans/model/_openai_model.py | 49 ++++++++++++-------- heymans/prompt.py | 8 +--- heymans/tools/_base_tool.py | 27 ++++++++--- tests/cheap/test_tools_json.py | 76 ------------------------------- 11 files changed, 121 insertions(+), 154 deletions(-) create mode 100644 heymans/model/_dummy_model.py delete mode 100644 tests/cheap/test_tools_json.py diff --git a/heymans/heymans.py b/heymans/heymans.py index 6d9f6bb..b7af154 100644 --- a/heymans/heymans.py +++ b/heymans/heymans.py @@ -1,5 +1,6 @@ import logging import jinja2 +import json from types import GeneratorType from typing import Tuple, Optional from . import config, library @@ -147,15 +148,20 @@ def _answer(self, state: str = 'answer') -> GeneratorType: logger.info(f'[{state} state] reply: {reply}') # If the reply is a callable, then it's a tool that we need to run if callable(reply): - tool_message, result, needs_feedback = reply() + tool_message, tool_result, needs_feedback = reply() if needs_feedback: logger.info(f'[{state} state] tools need feedback') + if not self.answer_model.supports_tool_feedback: + logger.info( + f'[{state} state] model does not support feedback') + needs_feedback = False metadata = self.messages.append('assistant', tool_message) yield tool_message, metadata # If the tool has a result, yield and remember it - if result: - metadata = self.messages.append('tool', result) - yield result, metadata + if tool_result: + metadata = self.messages.append('tool', + json.dumps(tool_result)) + yield tool_result['content'], metadata # Otherwise the reply is a regular AI message else: metadata = self.messages.append('assistant', reply) diff --git a/heymans/messages.py b/heymans/messages.py index f3837db..8813fd7 100644 --- a/heymans/messages.py +++ b/heymans/messages.py @@ -146,15 +146,11 @@ def _system_prompt(self): """The system prompt that is used for question answering consists of several fragments. """ - # There is always and identity, information about the current time, - # and a list of attached files + # There is always and identity and a list of attached files if self._heymans.search_first: system_prompt = [prompt.SYSTEM_PROMPT_IDENTITY_WITH_SEARCH] else: system_prompt = [prompt.SYSTEM_PROMPT_IDENTITY_WITHOUT_SEARCH] - system_prompt.append( - prompt.render(prompt.SYSTEM_PROMPT_DATETIME, - current_datetime=utils.current_datetime())) system_prompt.append( attachments.attachments_prompt(self._heymans.database)) # For models that support this, there is also an instruction indicating diff --git a/heymans/model/__init__.py b/heymans/model/__init__.py index b2d34d1..dcefa23 100644 --- a/heymans/model/__init__.py +++ b/heymans/model/__init__.py @@ -1,30 +1,29 @@ from ._base_model import BaseModel -from ._openai_model import OpenAIModel -from ._mistral_model import MistralModel -from ._anthropic_model import AnthropicModel -class DummyModel(BaseModel): - def predict(self, messages): - return 'dummy reply' - - def model(heymans, model, **kwargs): - + """A factory function that returns a Model instance.""" if model == 'gpt-4': + from ._openai_model import OpenAIModel return OpenAIModel(heymans, 'gpt-4-1106-preview', **kwargs) if model == 'gpt-3.5': + from ._openai_model import OpenAIModel return OpenAIModel(heymans, 'gpt-3.5-turbo-1106', **kwargs) if model == 'claude-2.1': + from ._anthropic_model import AnthropicModel return AnthropicModel(heymans, 'claude-2.1', **kwargs) if model == 'claude-3-opus': + from ._anthropic_model import AnthropicModel return AnthropicModel(heymans, 'claude-3-opus-20240229', **kwargs) if model == 'claude-3-sonnet': + from ._anthropic_model import AnthropicModel return AnthropicModel(heymans, 'claude-3-sonnet-20240229', **kwargs) if model.startswith('mistral-'): + from ._mistral_model import MistralModel if not model.endswith('-latest'): model += '-latest' return MistralModel(heymans, model, **kwargs) if model == 'dummy': + from ._dummy_model import DummyModel return DummyModel(heymans, **kwargs) raise ValueError(f'Unknown model: {model}') diff --git a/heymans/model/_anthropic_model.py b/heymans/model/_anthropic_model.py index c6b048b..e0cf8b1 100644 --- a/heymans/model/_anthropic_model.py +++ b/heymans/model/_anthropic_model.py @@ -1,11 +1,14 @@ from . import BaseModel from .. import config, utils import logging +import json logger = logging.getLogger('heymans') class AnthropicModel(BaseModel): + supports_not_done_yet = False + def __init__(self, heymans, model, **kwargs): from anthropic import Anthropic, AsyncAnthropic super().__init__(heymans, **kwargs) @@ -30,38 +33,41 @@ def predict(self, messages): logger.info('entering message postprocessing loop') for i, message in enumerate(messages): if message['role'] == 'tool': + if i == 0: + raise ValueError( + 'The first message cannot be a tool message') logger.info('converting tool message to user message') + tool_info = json.loads(message['content']) message['role'] = 'user' message['content'] = [{ 'type': 'tool_result', 'tool_use_id': str(self._tool_use_id), 'content': [{ 'type': 'text', - 'text': message['content'] + 'text': tool_info['content'] }] }] - if i > 0: - # The previous message needs to have a tool-use block - prev_message = messages[i - 1] - prev_message['content'] = [ - {'type': 'text', - 'text': prev_message['content']}, - {'type': 'tool_use', - 'id': str(self._tool_use_id), - 'input': {'args': 'input args'}, - 'name': 'tool_function' - } - ] + # The previous message needs to have a tool-use block + prev_message = messages[i - 1] + prev_message['content'] = [ + {'type': 'text', + 'text': prev_message['content']}, + {'type': 'tool_use', + 'id': str(self._tool_use_id), + 'input': {'args': tool_info['args']}, + 'name': tool_info['name'] + } + ] self._tool_use_id += 1 if len(messages) > i + 1: - logger.info('merging tool and user message') next_message = messages[i + 1] if next_message['role'] == 'user': + logger.info('merging tool and user message') message['content'].append([{ "type": "text", "text": next_message['content'] }]) - break + break else: break logger.info('dropping duplicate user message') @@ -74,7 +80,7 @@ def get_response(self, response): if block.type == 'tool_use': for tool in self._tools: if tool.name == block.name: - return tool.bind(block.input) + return tool.bind(json.dumps(block.input)) return self.invalid_tool if block.type == 'text': text.append(block.text) diff --git a/heymans/model/_base_model.py b/heymans/model/_base_model.py index 6458e9a..ee69020 100644 --- a/heymans/model/_base_model.py +++ b/heymans/model/_base_model.py @@ -7,8 +7,13 @@ class BaseModel: + """Base implementation for LLM chat models.""" - supports_not_done_yet = False + # Indicates whether the model is able to provide feedback on its own output + supports_not_done_yet = True + # Indicates whether the model is able to provide feedback on tool results + supports_tool_feedback = True + # Approximation to keep track of token costs characters_per_token = 4 def __init__(self, heymans, tools=None, tool_choice='auto'): @@ -20,7 +25,7 @@ def __init__(self, heymans, tools=None, tool_choice='auto'): self.completion_tokens_consumed = 0 def invalid_tool(self) -> str: - return 'Invalid tool' + return 'Invalid tool', None, False def get_response(self, response) -> [str, callable]: return response.content @@ -37,7 +42,7 @@ def async_invoke(self, messages): def messages_length(self, messages) -> int: if isinstance(messages, str): - return lebase_format_toolsn(messages) + return len(messages) return sum([len(m.content if hasattr(m, 'content') else m['content']) for m in messages]) diff --git a/heymans/model/_dummy_model.py b/heymans/model/_dummy_model.py new file mode 100644 index 0000000..0037a1b --- /dev/null +++ b/heymans/model/_dummy_model.py @@ -0,0 +1,17 @@ +from ._base_model import BaseModel + + +class DummyModel(BaseModel): + + def get_response(self, response): + return response + + def invoke(self, messages): + return 'dummy reply' + + async def _async_task(self): + return 'dummy reply' + + def async_invoke(self, messages): + import asyncio + return asyncio.create_task(self._async_task()) diff --git a/heymans/model/_mistral_model.py b/heymans/model/_mistral_model.py index 5a2e7f2..af83f00 100644 --- a/heymans/model/_mistral_model.py +++ b/heymans/model/_mistral_model.py @@ -1,5 +1,6 @@ from .. import config, utils -from . import OpenAIModel, BaseModel +from . import BaseModel +from ._openai_model import OpenAIModel from langchain.schema import SystemMessage, AIMessage, HumanMessage, \ FunctionMessage @@ -7,6 +8,7 @@ class MistralModel(OpenAIModel): supports_not_done_yet = False + supports_tool_feedback = False def __init__(self, heymans, model, **kwargs): from mistralai.async_client import MistralAsyncClient @@ -18,11 +20,6 @@ def __init__(self, heymans, model, **kwargs): self._client = MistralClient(api_key=config.mistral_api_key) self._async_client = MistralAsyncClient(api_key=config.mistral_api_key) - def convert_message(self, message): - from mistralai.models.chat_completion import ChatMessage - message = super().convert_message(message) - return ChatMessage(**message) - def predict(self, messages): if isinstance(messages, str): messages = [self.convert_message(messages)] @@ -30,7 +27,8 @@ def predict(self, messages): messages = utils.prepare_messages(messages, allow_ai_first=False, allow_ai_last=False, merge_consecutive=True) - messages = [self.convert_message(message) for message in messages] + messages = [self.convert_message(message) for message in messages] + messages = self._prepare_tool_messages(messages) return BaseModel.predict(self, messages) def invoke(self, messages): diff --git a/heymans/model/_openai_model.py b/heymans/model/_openai_model.py index 9180caf..9899957 100644 --- a/heymans/model/_openai_model.py +++ b/heymans/model/_openai_model.py @@ -1,11 +1,10 @@ +import json from .. import config from . import BaseModel class OpenAIModel(BaseModel): - supports_not_done_yet = True - def __init__(self, heymans, model, **kwargs): from openai import Client, AsyncClient super().__init__(heymans, **kwargs) @@ -23,26 +22,35 @@ def predict(self, messages): messages = [self.convert_message(messages)] else: messages = [self.convert_message(message) for message in messages] - # OpenAI requires the tool message to be linked to the previous AI - # message with a tool_call_id. The actual content doesn't appear to - # matter, so here we dummy-link the messages - for i, message in enumerate(messages): - if i == 0 or message['role'] != 'tool': - continue - tool_call_id = f'call_{i}' - prev_message = messages[i - 1] - prev_message['tool_calls'] = [ - { - 'id': tool_call_id, - 'type': 'function', - 'function': { - 'name': 'dummy', - 'arguments': '' - } - }] - message['tool_call_id'] = tool_call_id + messages = self._prepare_tool_messages(messages) return super().predict(messages) + def _prepare_tool_messages(self, messages): + # OpenAI requires the tool message to be linked to the previous AI + # message with a tool_call_id. The actual content doesn't appear to + # matter, so here we dummy-link the messages + for i, message in enumerate(messages): + if i == 0 or message['role'] != 'tool': + continue + tool_info = json.loads(message['content']) + tool_call_id = f'call_{i}' + prev_message = messages[i - 1] + # an assistant message should not have both content and tool calls + prev_message['content'] = '' + prev_message['tool_calls'] = [ + { + 'id': tool_call_id, + 'type': 'function', + 'function': { + 'name': tool_info['name'], + 'arguments': tool_info['args'] + } + }] + message['tool_call_id'] = tool_call_id + message['name'] = tool_info['name'] + message['content'] = tool_info['content'] + return messages + def get_response(self, response): tool_calls = response.choices[0].message.tool_calls if tool_calls: @@ -50,6 +58,7 @@ def get_response(self, response): for tool in self._tools: if tool.name == function.name: return tool.bind(function.arguments) + logger.warning(f'invalid tool called: {function}') return self.invalid_tool return response.choices[0].message.content diff --git a/heymans/prompt.py b/heymans/prompt.py index 6d0b395..5f7856f 100644 --- a/heymans/prompt.py +++ b/heymans/prompt.py @@ -6,16 +6,10 @@ # The system prompt used during question answering is composed of the fragments # below SYSTEM_PROMPT_IDENTITY_WITH_SEARCH = '''You are Sigmund, a brilliant AI assistant for users of OpenSesame, a program for building psychology and neuroscience experiments. You sometimes use emojis.''' -SYSTEM_PROMPT_IDENTITY_WITHOUT_SEARCH = '''You are Sigmund, a brilliant AI assistant. You sometimes use emojis. - -If and only if the user asks a question related to OpenSesame, a program for building psychology and neuroscience experiments, don't answer the question. Instead, suggest that the user gives you access to the OpenSesame documentation by enabling OpenSesame expert mode through the menu of this chat web application. -''' +SYSTEM_PROMPT_IDENTITY_WITHOUT_SEARCH = '''You are Sigmund, a brilliant AI assistant. You sometimes use emojis.''' # Sent by AI to indicate that message requires for replies or actions NOT_DONE_YET_MARKER = '' SYSTEM_PROMPT_NOT_DONE_YET = f'''When you intend to perform an action ("please wait", "I will now"), such as searching or code execution, end your reply with {NOT_DONE_YET_MARKER}.''' -SYSTEM_PROMPT_DATETIME = '''# Current date and time - -{{ current_datetime }}''' SYSTEM_PROMPT_ATTACHMENTS = '''# Attachments You have access to the following attached files: diff --git a/heymans/tools/_base_tool.py b/heymans/tools/_base_tool.py index c3b9680..f052524 100644 --- a/heymans/tools/_base_tool.py +++ b/heymans/tools/_base_tool.py @@ -9,13 +9,14 @@ class BaseTool: """A base class for tools that process an AI reply.""" - tool_spec = None - def __init__(self, heymans): self._heymans = heymans @property def tool_spec(self): + """The tool spec should corresond to the OpenAI specification for + function tools. + """ return { "name": self.__class__.__name__, "description": self.__doc__, @@ -30,11 +31,23 @@ def tool_spec(self): def name(self): return self.__class__.__name__ - def bind(self, args): - print(f'binding tool to: {args}') - if isinstance(args, str): - args = json.loads(args) - return functools.partial(self, **args) + def bind(self, args: str) -> callable: + """Returns a callable that corresponds to a tool function called with + a string of arguments, which should be in JSON format. The callable + itself returns a (message: str, result: dict, needs_feedback: bool) + tuple, where message is an informative text message as generated by + the tool, result is a dict with name, args, and content keys that + correspond to the name and arguments of the function and the result of + the tool call. needs_feedback indicates whether the model should be + called again to provide feedback based on the tool result. + """ + def bound_tool_function(): + message, result, needs_feedback = self(**json.loads(args)) + result = {'name': self.name, + 'args': args, + 'content': result} + return message, result, needs_feedback + return bound_tool_function def __call__(self) -> Tuple[str, Optional[str], bool]: """Should be implemented in a tool with additional arguments that diff --git a/tests/cheap/test_tools_json.py b/tests/cheap/test_tools_json.py deleted file mode 100644 index fde22af..0000000 --- a/tests/cheap/test_tools_json.py +++ /dev/null @@ -1,76 +0,0 @@ -from heymans.tools import BaseTool, CodeExecutionTool, GoogleScholarTool - - -def test_tools_json(): - - class TestTool(BaseTool): - json_pattern = CodeExecutionTool.json_pattern - def use(self, message, language, code): - return '', False - - tool = TestTool(None) - - target = 'Some text' - message = '''Some text - -{ - "execute_code": - { - "language": "python", - "code": "pass" - } -}''' - target_message, results, needs_reply = tool.run(message) - assert target_message.strip() == target - message = '''Some text - -```json -{ - "execute_code": - { - "language": "python", - "code": "pass" - } -} -``` -''' - target_message, results, needs_reply = tool.run(message) - assert target_message.strip() == target - target = 'Running `TestTool` … ' - message = '''{ - "execute_code": - { - "language": "python", - "code": "pass" - } -}''' - target_message, results, needs_reply = tool.run(message) - assert target_message.strip() == target - - class TestTool(BaseTool): - json_pattern = GoogleScholarTool.json_pattern - def use(self, message, queries): - return '', False - - tool = TestTool(None) - message = '''Sure, I can help you find articles that are related to your manuscript's topic. This might provide us with authors who have recently published in this field, and they could be potential reviewers for your manuscript. Let's perform a search on Google Scholar to find such articles. - -Please wait a moment while I perform the search. - -{ "search_google_scholar": [ "Open-Access Database of Video Stimuli for Action Observation", "Psychometric Evaluation in Neuroimaging", "Motion Characterization in Neuroimaging Settings" ] }''' - target_message, results, needs_reply = tool.run(message) - assert target_message.strip() == '''Sure, I can help you find articles that are related to your manuscript's topic. This might provide us with authors who have recently published in this field, and they could be potential reviewers for your manuscript. Let's perform a search on Google Scholar to find such articles. - -Please wait a moment while I perform the search.''' - message = '''Sure, I can help with that. I'll perform a search on Google Scholar for articles related to pupil size. Just a moment, please. - -```json -{ - "search_google_scholar": [ - "pupil size" - ] -} -``` -''' - target_message, results, needs_reply = tool.run(message) - assert target_message.strip() == '''Sure, I can help with that. I'll perform a search on Google Scholar for articles related to pupil size. Just a moment, please.'''