From 456960c3302f42c0df8c23604bf38292dc2c8677 Mon Sep 17 00:00:00 2001 From: smathot Date: Fri, 12 Apr 2024 21:41:33 +0200 Subject: [PATCH] Refactor tool use - Anthropic needs to be implemented - Mistral tends to repeat tool uses - OpenAI works well --- heymans/__init__.py | 2 +- heymans/config.py | 15 +- heymans/database/manager.py | 1 + heymans/heymans.py | 98 +++++----- heymans/messages.py | 39 ++-- heymans/model.py | 173 ------------------ heymans/model/__init__.py | 30 +++ heymans/model/_anthropic_model.py | 28 +++ heymans/model/_base_model.py | 91 +++++++++ heymans/model/_mistral_model.py | 42 +++++ heymans/model/_openai_model.py | 89 +++++++++ heymans/prompt.py | 25 +-- heymans/routes/app.py | 2 +- heymans/templates/stylesheet.css.jinja | 11 ++ heymans/tools/__init__.py | 13 +- heymans/tools/_base_tool.py | 42 +++++ .../tools/{download_tool.py => _download.py} | 24 +-- ...ode_execution_tool.py => _execute_code.py} | 44 +++-- heymans/tools/_read_attachment.py | 55 ++++++ heymans/tools/_search_documentation.py | 64 +++++++ heymans/tools/_search_google_scholar.py | 39 ++++ heymans/tools/attachments_tool.py | 50 ----- heymans/tools/base_tool.py | 99 ---------- heymans/tools/google_scholar_tool.py | 37 ---- heymans/tools/search_tool.py | 18 -- tests/expensive/expensive_test_utils.py | 1 + ...attachments.py => test_tool_attachment.py} | 1 + ...ools_download.py => test_tool_download.py} | 0 tests/expensive/test_tool_execute_code.py | 10 + .../test_tool_search_documentation.py | 16 ++ ....py => test_tool_search_google_scholar.py} | 6 +- 31 files changed, 633 insertions(+), 532 deletions(-) delete mode 100644 heymans/model.py create mode 100644 heymans/model/__init__.py create mode 100644 heymans/model/_anthropic_model.py create mode 100644 heymans/model/_base_model.py create mode 100644 heymans/model/_mistral_model.py create mode 100644 heymans/model/_openai_model.py create mode 100644 heymans/tools/_base_tool.py rename heymans/tools/{download_tool.py => _download.py} (84%) rename heymans/tools/{code_execution_tool.py => _execute_code.py} (61%) create mode 100644 heymans/tools/_read_attachment.py create mode 100644 heymans/tools/_search_documentation.py create mode 100644 heymans/tools/_search_google_scholar.py delete mode 100644 heymans/tools/attachments_tool.py delete mode 100644 heymans/tools/base_tool.py delete mode 100644 heymans/tools/google_scholar_tool.py delete mode 100644 heymans/tools/search_tool.py rename tests/expensive/{test_tools_attachments.py => test_tool_attachment.py} (96%) rename tests/expensive/{test_tools_download.py => test_tool_download.py} (100%) create mode 100644 tests/expensive/test_tool_execute_code.py create mode 100644 tests/expensive/test_tool_search_documentation.py rename tests/expensive/{test_tools_google_scholar.py => test_tool_search_google_scholar.py} (71%) diff --git a/heymans/__init__.py b/heymans/__init__.py index 097c67e..f95aa97 100644 --- a/heymans/__init__.py +++ b/heymans/__init__.py @@ -1,3 +1,3 @@ """AI-based chatbot that provides sensible answers based on documentation""" -__version__ = '0.13.16' +__version__ = '0.14.0' diff --git a/heymans/config.py b/heymans/config.py index 12b3ea3..6b95165 100644 --- a/heymans/config.py +++ b/heymans/config.py @@ -71,10 +71,9 @@ ''' # The default title of a new conversation default_conversation_title = 'New conversation' -# The number of previous messages for which transient content should be -# retained. Transient content are large chunks of information that are included -# in AI messages, usually as the result of tool use. -keep_transient = 4 +# The number of previous messages for which tool results should be +# retained. +keep_tool_results = 4 # RATE LIMITS # @@ -125,7 +124,7 @@ 'answer_model': 'claude-3-opus' }, 'mistral': { - 'search_model': 'mistral-medium', + 'search_model': 'mistral-large', 'condense_model': 'mistral-medium', 'answer_model': 'mistral-large' }, @@ -140,11 +139,11 @@ # # Tools should match the names of classes from heymans.tools # Search tools are executed in the first documentation-search phase -search_tools = ['TopicsTool', 'SearchTool'] +search_tools = ['search_documentation'] # Answer tools are executed during the answer phase answer_tools_with_search = [] -answer_tools_without_search = ['CodeExecutionTool', 'GoogleScholarTool', - 'AttachmentsTool', 'DownloadTool'] +answer_tools_without_search = ['read_attachment', 'search_google_scholar', + 'execute_code', 'download'] # SETTINGS # diff --git a/heymans/database/manager.py b/heymans/database/manager.py index 9f5f547..5ac55a9 100644 --- a/heymans/database/manager.py +++ b/heymans/database/manager.py @@ -188,6 +188,7 @@ def add_attachment(self, attachment_data: dict) -> int: db.session.commit() return attachment.attachment_id except Exception as e: + breakpoint() logger.error(f"Error adding attachment: {e}") return -1 diff --git a/heymans/heymans.py b/heymans/heymans.py index 9deb35e..6d9f6bb 100644 --- a/heymans/heymans.py +++ b/heymans/heymans.py @@ -49,9 +49,6 @@ def __init__(self, user_id: str, persistent: bool = False, ] self.documentation = Documentation( self, sources=[FAISSDocumentationSource(self)]) - self.search_model = model(self, self.model_config['search_model']) - self.answer_model = model(self, self.model_config['answer_model']) - self.condense_model = model(self, self.model_config['condense_model']) self.messages = Messages(self, persistent) if search_tools is None: search_tools = config.search_tools @@ -64,7 +61,23 @@ def __init__(self, user_id: str, persistent: bool = False, # instantiated with heymans (self) as first argument self.search_tools = [getattr(tools, t)(self) for t in search_tools] self.answer_tools = [getattr(tools, t)(self) for t in answer_tools] - self.tools = self.answer_tools + # If there are search tools, the first one should always be used + if search_tools: + search_tool_choice = search_tools[0] + else: + search_tool_choice = None + # If there are answer tools, the mode can choose freely + if answer_tools: + answer_tool_choice = 'auto' + else: + answer_tool_choice = None + self.search_model = model(self, self.model_config['search_model'], + tools=self.search_tools, + tool_choice=search_tool_choice) + self.answer_model = model(self, self.model_config['answer_model'], + tools=self.answer_tools, + tool_choice=answer_tool_choice) + self.condense_model = model(self, self.model_config['condense_model']) def send_user_message(self, message: str, message_id: str=None) -> GeneratorType: @@ -105,12 +118,14 @@ def _search(self, message: str) -> GeneratorType: self.documentation.search([message]) # Then search based on the search-model queries derived from the user # question - self.tools = self.search_tools reply = self.search_model.predict(self.messages.prompt( system_prompt=prompt.SYSTEM_PROMPT_SEARCH)) if config.log_replies: logger.info(f'[search state] reply: {reply}') - self._run_tools(reply) + if callable(reply): + reply() + else: + logger.warning(f'[search state] did not call search tool') self.documentation.strip_irrelevant(message) logger.info( f'[search state] {len(self.documentation._documents)} documents, {len(self.documentation)} characters') @@ -120,7 +135,6 @@ def _answer(self, state: str = 'answer') -> GeneratorType: yield {'action': 'set_loading_indicator', 'message': f'{config.ai_name} is thinking and typing '}, {} logger.info(f'[{state} state] entering') - self.tools = self.answer_tools # We first collect a regular reply to the user message. While doing so # we also keep track of the number of tokens consumed. tokens_consumed_before = self.answer_model.total_tokens_consumed @@ -131,53 +145,37 @@ def _answer(self, state: str = 'answer') -> GeneratorType: self.database.add_activity(tokens_consumed) if config.log_replies: logger.info(f'[{state} state] reply: {reply}') - # We then run tools based on the AI reply. This may modify the reply, - # mainly by stripping out any JSON commands in the reply - reply, result, needs_feedback = self._run_tools(reply) - if needs_feedback: - logger.info(f'[{state} state] tools need feedback') - # If the reply contains a NOT_DONE_YET marker, this is a way for the AI - # to indicate that it wants to perform additional actions. This makes - # it easier to perform tasks consisting of multiple responses and - # actions. The marker is stripped from the reply so that it's hidden - # from the user. We also check for a number of common linguistic - # indicators that the AI isn't done yet, such "I will now". This is - # necessary because the explicit marker isn't reliably sent. - if self.answer_model.supports_not_done_yet and \ - prompt.NOT_DONE_YET_MARKER in reply: - reply = reply.replace(prompt.NOT_DONE_YET_MARKER, '') - logger.info(f'[{state} state] not-done-yet marker received') - needs_feedback = True - # If there is still a non-empty reply after running the tools (i.e. - # stripping the JSON hasn't cleared the reply entirely, then yield and - # remember it. - if reply: + # If the reply is a callable, then it's a tool that we need to run + if callable(reply): + tool_message, result, needs_feedback = reply() + if needs_feedback: + logger.info(f'[{state} state] tools need feedback') + metadata = self.messages.append('assistant', tool_message) + yield tool_message, metadata + # If the tool has a result, yield and remember it + if result: + metadata = self.messages.append('tool', result) + yield result, metadata + # Otherwise the reply is a regular AI message + else: metadata = self.messages.append('assistant', reply) yield reply, metadata - else: - metadata = self.messages.metadata() - # If the tools have a result, yield and remember it - if result: - self.messages.append('assistant', result) - yield result, metadata + # If the reply contains a NOT_DONE_YET marker, this is a way for the AI + # to indicate that it wants to perform additional actions. This makes + # it easier to perform tasks consisting of multiple responses and + # actions. The marker is stripped from the reply so that it's hidden + # from the user. We also check for a number of common linguistic + # indicators that the AI isn't done yet, such "I will now". This is + # necessary because the explicit marker isn't reliably sent. + if self.answer_model.supports_not_done_yet and \ + prompt.NOT_DONE_YET_MARKER in reply: + reply = reply.replace(prompt.NOT_DONE_YET_MARKER, '') + logger.info(f'[{state} state] not-done-yet marker received') + needs_feedback = True + else: + needs_feedback = False # If feedback is required, either because the tools require it or # because the AI sent a NOT_DONE_YET marker, go for another round. if needs_feedback and not self._rate_limit_exceeded(): for reply, metadata in self._answer(state='feedback'): yield reply, metadata - - def _run_tools(self, reply: str) -> Tuple[str, str, bool]: - """Runs all tools on a reply. Returns the modified reply, a string - that concatenates all output (an empty string if no output was - produced) and a bool indicating whether the AI should in turn repond - to the produced output. - """ - logger.info(f'running tools') - results = [] - needs_reply = [] - for tool in self.tools: - reply, tool_results, tool_needs_reply = tool.run(reply) - if tool_results: - results += tool_results - needs_reply.append(tool_needs_reply) - return reply, '\n\n'.join(results), any(needs_reply) diff --git a/heymans/messages.py b/heymans/messages.py index ac2d431..f3837db 100644 --- a/heymans/messages.py +++ b/heymans/messages.py @@ -8,15 +8,13 @@ from cryptography.fernet import InvalidToken from .model import model from . import prompt, config, utils, attachments -from langchain.schema import HumanMessage, AIMessage, SystemMessage +from langchain.schema import HumanMessage, AIMessage, SystemMessage, \ + FunctionMessage logger = logging.getLogger('heymans') class Messages: - regex_transient = re.compile(r"
.*?
", - re.DOTALL) - def __init__(self, heymans, persistent=False): self._heymans = heymans self._persistent = persistent @@ -85,15 +83,9 @@ def delete(self, message_id): if self._persistent: self.save() - def _message_is_transient(self, content): - return self.regex_transient.search(content) - def prompt(self, system_prompt=None): """The prompt consists of the system prompt followed by a sequence of - AI and user messages. Transient messages are special messages that are - hidden except when they are the last message. This allows the AI to - feed some information to itself to respond to without confounding the - rest of the conversation. + AI, user, and tool/ function messages. If no system prompt is provided, one is automatically constructed. Typically, an explicit system_prompt is provided during the search @@ -105,16 +97,15 @@ def prompt(self, system_prompt=None): msg_len = len(self._condensed_message_history) for msg_nr, (role, content) in enumerate( self._condensed_message_history): - # Messages may contain transient content, such as attachment text, - # which are removed if they are a few messages away in the history. - # This avoid the prompt from becoming too large. - if msg_nr + config.keep_transient < msg_len: - if self._message_is_transient(content): - content = '' if role == 'assistant': model_prompt.append(AIMessage(content=content)) elif role == 'user': model_prompt.append(HumanMessage(content=content)) + elif role == 'tool': + if msg_nr + config.keep_tool_results < msg_len: + continue + model_prompt.append(FunctionMessage(content=content, + name='tool_function')) else: raise ValueError(f'Invalid role: {role}') return model_prompt @@ -129,7 +120,7 @@ def _condense_message_history(self): messages = [{"role": "system", "content": system_prompt}] prompt_length = sum(len(content) for role, content in self._condensed_message_history - if not self._message_is_transient(content)) + if role != 'tool') logger.info(f'system prompt length: {len(system_prompt)}') logger.info(f'prompt length (without system prompt): {prompt_length}') if prompt_length <= config.max_prompt_length: @@ -169,16 +160,13 @@ def _system_prompt(self): # For models that support this, there is also an instruction indicating # that a special marker can be sent to indicate that the response isn't # done yet. Not all models support this to avoid infinite loops. - if self._heymans.answer_model.supports_not_done_yet and \ - self._heymans.tools: + if self._heymans.answer_model.supports_not_done_yet: system_prompt.append(prompt.SYSTEM_PROMPT_NOT_DONE_YET) - # Each tool has an explanation - for tool in self._heymans.tools: - if tool.prompt: - system_prompt.append(tool.prompt) # If available, documentation is also included in the prompt if len(self._heymans.documentation): system_prompt.append(self._heymans.documentation.prompt()) + system_prompt.append( + attachments.attachments_prompt(self._heymans.database)) # And finally, if the message history has been condensed, this is also # included. if self._condensed_text: @@ -186,7 +174,8 @@ def _system_prompt(self): system_prompt.append(prompt.render( prompt.SYSTEM_PROMPT_CONDENSED, summary=self._condensed_text)) - return '\n\n'.join(system_prompt) + # Combine all non-empty prompt chunks + return '\n\n'.join(chunk for chunk in system_prompt if chunk.strip()) def _update_title(self): """The conversation title is updated when there are at least two diff --git a/heymans/model.py b/heymans/model.py deleted file mode 100644 index 9ccc3e6..0000000 --- a/heymans/model.py +++ /dev/null @@ -1,173 +0,0 @@ -from . import config, utils -import re -import json -import logging -import asyncio -from langchain_community.callbacks import get_openai_callback -import time -logger = logging.getLogger('heymans') - - -class BaseModel: - - supports_not_done_yet = False - characters_per_token = 4 - - def __init__(self, heymans): - self._heymans = heymans - self.total_tokens_consumed = 0 - self.prompt_tokens_consumed = 0 - self.completion_tokens_consumed = 0 - - def predict(self, messages, track_tokens=True): - t0 = time.time() - logger.info(f'predicting with {self.__class__} model') - if isinstance(messages, str): - prompt_tokens = len(messages) // self.characters_per_token - reply = self._model.invoke(messages).content - dt = time.time() - t0 - logger.info(f'predicting {len(reply) + len(messages)} took {dt} s') - else: - reply = self._model.invoke(messages).content - dt = time.time() - t0 - msg_len = sum([len(m.content) for m in messages]) - prompt_tokens = msg_len // self.characters_per_token - logger.info(f'predicting {len(reply) + msg_len} took {dt} s') - if track_tokens: - completion_tokens = len(reply) // self.characters_per_token - total_tokens = prompt_tokens + completion_tokens - self.total_tokens_consumed += total_tokens - self.prompt_tokens_consumed += prompt_tokens - self.completion_tokens_consumed += completion_tokens - logger.info(f'total tokens (approx.): {total_tokens}') - logger.info(f'prompt tokens (approx.): {prompt_tokens}') - logger.info(f'completion tokens (approx.): {completion_tokens}') - return reply - - def predict_multiple(self, prompts): - """Predicts multiple simple (non-message history) prompts using asyncio - if possible. - """ - try: - loop = asyncio.get_event_loop() - if not loop.is_running(): - logger.info('re-using async event loop') - use_async = True - else: - logger.info('async event loop is already running') - use_async = False - except RuntimeError as e: - logger.info('creating async event loop') - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - use_async = True - - if not use_async: - logger.info('predicting multiple without async') - return [self._model.invoke(prompt).content for prompt in prompts] - - async def wrap_gather(): - tasks = [self._model.ainvoke(prompt) for prompt in prompts] - predictions = await asyncio.gather(*tasks) - return [p.content for p in predictions] - - logger.info('predicting multiple using async') - return loop.run_until_complete(wrap_gather()) - - -class OpenAIModel(BaseModel): - - supports_not_done_yet = True - - def __init__(self, heymans, model): - from langchain_openai.chat_models import ChatOpenAI - super().__init__(heymans) - self._model = ChatOpenAI( - model=model, - openai_api_key=config.openai_api_key) - - def predict(self, messages): - with get_openai_callback() as cb: - retval = super().predict(messages, track_tokens=False) - logger.info(cb) - self.total_tokens_consumed += cb.total_tokens - self.prompt_tokens_consumed += cb.prompt_tokens - self.completion_tokens_consumed += cb.completion_tokens - return retval - - def predict_multiple(self, prompts): - with get_openai_callback() as cb: - retval = super().predict_multiple(prompts) - logger.info(cb) - self.total_tokens_consumed += cb.total_tokens - self.prompt_tokens_consumed += cb.prompt_tokens - self.completion_tokens_consumed += cb.completion_tokens - return retval - - -class ClaudeModel(BaseModel): - - max_retry = 3 - - def __init__(self, heymans, model): - from langchain_anthropic import ChatAnthropic - super().__init__(heymans) - self._model = ChatAnthropic( - model=model, anthropic_api_key=config.anthropic_api_key) - - def predict(self, messages): - if isinstance(messages, list): - messages = utils.prepare_messages(messages, allow_ai_first=False, - allow_ai_last=False, - merge_consecutive=True) - # Claude seems to crash occasionally, in which case a retry will do the - # trick - for i in range(self.max_retry): - try: - return super().predict(messages) - except Exception as e: - logger.warning(f'error in prediction (retrying): {e}') - return super().predict(messages) - - -class MistralModel(BaseModel): - - def __init__(self, heymans, model): - from langchain_mistralai.chat_models import ChatMistralAI - super().__init__(heymans) - self._model = ChatMistralAI( - model=model, - openai_api_key=config.mistral_api_key) - - def predict(self, messages): - if isinstance(messages, list): - messages = utils.prepare_messages(messages, allow_ai_first=False, - allow_ai_last=False, - merge_consecutive=True) - return super().predict(messages) - - -class DummyModel(BaseModel): - def predict(self, messages): - return 'dummy reply' - - -def model(heymans, model): - - if model == 'gpt-4': - return OpenAIModel(heymans, 'gpt-4-1106-preview') - if model == 'gpt-3.5': - return OpenAIModel(heymans, 'gpt-3.5-turbo-1106') - if model == 'claude-2.1': - return ClaudeModel(heymans, 'claude-2.1') - if model == 'claude-3-opus': - return ClaudeModel(heymans, 'claude-3-opus-20240229') - if model == 'claude-3-sonnet': - return ClaudeModel(heymans, 'claude-3-sonnet-20240229') - if model.startswith('mistral-'): - if not model.endswith('-latest'): - model += '-latest' - return MistralModel(heymans, model) - if model == 'dummy': - return DummyModel(heymans) - raise ValueError(f'Unknown model: {model}') diff --git a/heymans/model/__init__.py b/heymans/model/__init__.py new file mode 100644 index 0000000..b2d34d1 --- /dev/null +++ b/heymans/model/__init__.py @@ -0,0 +1,30 @@ +from ._base_model import BaseModel +from ._openai_model import OpenAIModel +from ._mistral_model import MistralModel +from ._anthropic_model import AnthropicModel + + +class DummyModel(BaseModel): + def predict(self, messages): + return 'dummy reply' + + +def model(heymans, model, **kwargs): + + if model == 'gpt-4': + return OpenAIModel(heymans, 'gpt-4-1106-preview', **kwargs) + if model == 'gpt-3.5': + return OpenAIModel(heymans, 'gpt-3.5-turbo-1106', **kwargs) + if model == 'claude-2.1': + return AnthropicModel(heymans, 'claude-2.1', **kwargs) + if model == 'claude-3-opus': + return AnthropicModel(heymans, 'claude-3-opus-20240229', **kwargs) + if model == 'claude-3-sonnet': + return AnthropicModel(heymans, 'claude-3-sonnet-20240229', **kwargs) + if model.startswith('mistral-'): + if not model.endswith('-latest'): + model += '-latest' + return MistralModel(heymans, model, **kwargs) + if model == 'dummy': + return DummyModel(heymans, **kwargs) + raise ValueError(f'Unknown model: {model}') diff --git a/heymans/model/_anthropic_model.py b/heymans/model/_anthropic_model.py new file mode 100644 index 0000000..e8f9cd8 --- /dev/null +++ b/heymans/model/_anthropic_model.py @@ -0,0 +1,28 @@ +from . import BaseModel +from .. import config + + +class AnthropicModel(BaseModel): + + max_retry = 3 + + def __init__(self, heymans, model): + from langchain_anthropic import ChatAnthropic + super().__init__(heymans) + self._model = ChatAnthropic( + model=model, anthropic_api_key=config.anthropic_api_key) + + def predict(self, messages): + if isinstance(messages, list): + messages = utils.prepare_messages(messages, allow_ai_first=False, + allow_ai_last=False, + merge_consecutive=True) + # Claude seems to crash occasionally, in which case a retry will do the + # trick + for i in range(self.max_retry): + try: + return super().predict(messages) + except Exception as e: + logger.warning(f'error in prediction (retrying): {e}') + return super().predict(messages) + diff --git a/heymans/model/_base_model.py b/heymans/model/_base_model.py new file mode 100644 index 0000000..f7f875e --- /dev/null +++ b/heymans/model/_base_model.py @@ -0,0 +1,91 @@ +import logging +import asyncio +import time +logger = logging.getLogger('heymans') + + +class BaseModel: + + supports_not_done_yet = False + characters_per_token = 4 + + def __init__(self, heymans, tools=None, tool_choice='auto'): + self._heymans = heymans + self._tools = tools + self._tool_choice = tool_choice + self.total_tokens_consumed = 0 + self.prompt_tokens_consumed = 0 + self.completion_tokens_consumed = 0 + + def invalid_tool(self): + return 'Invalid tool' + + def get_response(self, response) -> [str, callable]: + return response.content + + def tools(self): + return [{"type": "function", "function": t.tool_spec} + for t in self._tools if t.tool_spec] + + def invoke(self, messages): + return self._model.invoke(messages) + + def async_invoke(self, messages): + return self._model.ainvoke(messages) + + def messages_length(self, messages): + if isinstance(messages, str): + return len(messages) + return sum([len(m.content if hasattr(m, 'content') else m['content']) + for m in messages]) + + def predict(self, messages, track_tokens=True): + t0 = time.time() + logger.info(f'predicting with {self.__class__} model') + reply = self.get_response(self.invoke(messages)) + msg_len = self.messages_length(messages) + dt = time.time() - t0 + prompt_tokens = msg_len // self.characters_per_token + reply_len = len(reply) if isinstance(reply, str) else 0 + logger.info(f'predicting {reply_len + msg_len} took {dt} s') + if track_tokens: + completion_tokens = reply_len // self.characters_per_token + total_tokens = prompt_tokens + completion_tokens + self.total_tokens_consumed += total_tokens + self.prompt_tokens_consumed += prompt_tokens + self.completion_tokens_consumed += completion_tokens + logger.info(f'total tokens (approx.): {total_tokens}') + logger.info(f'prompt tokens (approx.): {prompt_tokens}') + logger.info(f'completion tokens (approx.): {completion_tokens}') + return reply + + def predict_multiple(self, prompts): + """Predicts multiple simple (non-message history) prompts using asyncio + if possible. + """ + try: + loop = asyncio.get_event_loop() + if not loop.is_running(): + logger.info('re-using async event loop') + use_async = True + else: + logger.info('async event loop is already running') + use_async = False + except RuntimeError as e: + logger.info('creating async event loop') + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + use_async = True + + if not use_async: + logger.info('predicting multiple without async') + return [self.get_response(self.invoke(prompt)) + for prompt in prompts] + + async def wrap_gather(): + tasks = [self.async_invoke(prompt) for prompt in prompts] + predictions = await asyncio.gather(*tasks) + return [self.get_response(p) for p in predictions] + + logger.info('predicting multiple using async') + return loop.run_until_complete(wrap_gather()) diff --git a/heymans/model/_mistral_model.py b/heymans/model/_mistral_model.py new file mode 100644 index 0000000..5a2e7f2 --- /dev/null +++ b/heymans/model/_mistral_model.py @@ -0,0 +1,42 @@ +from .. import config, utils +from . import OpenAIModel, BaseModel +from langchain.schema import SystemMessage, AIMessage, HumanMessage, \ + FunctionMessage + + +class MistralModel(OpenAIModel): + + supports_not_done_yet = False + + def __init__(self, heymans, model, **kwargs): + from mistralai.async_client import MistralAsyncClient + from mistralai.client import MistralClient + BaseModel.__init__(self, heymans, **kwargs) + self._model = model + if self._tool_choice is not None: + self._tool_choice = 'any' + self._client = MistralClient(api_key=config.mistral_api_key) + self._async_client = MistralAsyncClient(api_key=config.mistral_api_key) + + def convert_message(self, message): + from mistralai.models.chat_completion import ChatMessage + message = super().convert_message(message) + return ChatMessage(**message) + + def predict(self, messages): + if isinstance(messages, str): + messages = [self.convert_message(messages)] + else: + messages = utils.prepare_messages(messages, allow_ai_first=False, + allow_ai_last=False, + merge_consecutive=True) + messages = [self.convert_message(message) for message in messages] + return BaseModel.predict(self, messages) + + def invoke(self, messages): + return self._client.chat(model=self._model, messages=messages, + **self._tool_args()) + + def async_invoke(self, messages): + return self._async_client.chat(model=self._model, messages=messages, + **self._tool_args()) diff --git a/heymans/model/_openai_model.py b/heymans/model/_openai_model.py new file mode 100644 index 0000000..3717615 --- /dev/null +++ b/heymans/model/_openai_model.py @@ -0,0 +1,89 @@ +from .. import config +from . import BaseModel +from langchain.schema import SystemMessage, AIMessage, HumanMessage, \ + FunctionMessage + + +class OpenAIModel(BaseModel): + + supports_not_done_yet = True + + def __init__(self, heymans, model, **kwargs): + from openai import Client, AsyncClient + super().__init__(heymans, **kwargs) + self._model = model + if self._tool_choice not in (None, 'auto'): + self._tool_choice = {"type": "function", + "function": {"name": self._tool_choice}} + self._client = Client(api_key=config.openai_api_key) + self._async_client = AsyncClient(api_key=config.openai_api_key) + + def convert_message(self, message): + # OpenAI expects messages as dict objects + if isinstance(message, str): + return dict(role='user', content=message) + if isinstance(message, SystemMessage): + role = 'system' + elif isinstance(message, AIMessage): + role = 'assistant' + elif isinstance(message, HumanMessage): + role = 'user' + elif isinstance(message, FunctionMessage): + role = 'tool' + else: + raise ValueError(f'Unknown message type: {message}') + return dict(role=role, content=message.content) + + def predict(self, messages): + # Strings need to be converted a list of length one with a single + # message dict + if isinstance(messages, str): + messages = [self.convert_message(messages)] + else: + messages = [self.convert_message(message) for message in messages] + # OpenAI requires the tool message to be linked to the previous AI + # message with a tool_call_id. The actual content doesn't appear to + # matter, so here we dummy-link the messages + for i, message in enumerate(messages): + if i == 0 or message['role'] != 'tool': + continue + tool_call_id = f'call_{i}' + prev_message = messages[i - 1] + prev_message['tool_calls'] = [ + { + 'id': tool_call_id, + 'type': 'function', + 'function': { + 'name': 'dummy', + 'arguments': '' + } + }] + message['tool_call_id'] = tool_call_id + return super().predict(messages) + + def predict_multiple(self, prompts): + prompts = [[self.convert_message(prompt)] for prompt in prompts] + return super().predict_multiple(prompts) + + def get_response(self, response): + tool_calls = response.choices[0].message.tool_calls + if tool_calls: + function = tool_calls[0].function + for tool in self._tools: + if tool.name == function.name: + return tool.bind(function.arguments) + return self.invalid_tool + return response.choices[0].message.content + + def _tool_args(self): + if not self._tools: + return {} + return {'tools': self.tools(), 'tool_choice': self._tool_choice} + + def invoke(self, messages): + return self._client.chat.completions.create( + model=self._model, messages=messages, **self._tool_args()) + + def async_invoke(self, messages): + return self._async_client.chat.completions.create( + model=self._model, messages=messages, **self._tool_args()) diff --git a/heymans/prompt.py b/heymans/prompt.py index 6c27e7e..6d0b395 100644 --- a/heymans/prompt.py +++ b/heymans/prompt.py @@ -1,30 +1,7 @@ import jinja2 # The system prompt during documentation search consists of the prompt below -SYSTEM_PROMPT_SEARCH = '''You are Sigmund, an assistant for users of OpenSesame, a program for building psychology and neuroscience experiments. - -Do not answer the user's question. Instead, request documentation by replying with JSON in the format shown below. Use the "topics" field to indicate which topics are related to the question. Only use topics shown in the example. Do not make up your own topics. Use the "search" field to specify additional search queries that you feel are relevant. - -{ - "topics": [ - "opensesame", - "osweb", - "python", - "javascript", - "inline_script", - "inline_javascript", - "datamatrix", - "data_analysis", - "questions_howto" - ], - "search", [ - "search query 1", - "search query 2" - ] -} - -Respond only with JSON. Do not include additional text in your reply. -''' +SYSTEM_PROMPT_SEARCH = '''Do not answer the user's question. Instead, use the search_documentation function tool to search for relevant documentation.''' # The system prompt used during question answering is composed of the fragments # below diff --git a/heymans/routes/app.py b/heymans/routes/app.py index a845207..964ddb6 100644 --- a/heymans/routes/app.py +++ b/heymans/routes/app.py @@ -38,7 +38,7 @@ def chat_page(): for role, message, metadata in heymans.messages: message_id = metadata.get('message_id', 0) delete_button = f'' - if role == 'assistant': + if role in ('assistant', 'tool'): html_body = utils.md( f'{config.ai_name}: {config.process_ai_message(message)}') html_class = 'message-ai' diff --git a/heymans/templates/stylesheet.css.jinja b/heymans/templates/stylesheet.css.jinja index a9cea60..2ccbb75 100644 --- a/heymans/templates/stylesheet.css.jinja +++ b/heymans/templates/stylesheet.css.jinja @@ -377,3 +377,14 @@ button#start { float: right; margin: 0px 0px 20px 20px!important; } + +.google-scholar-search-results, +.attachment-content { + margin-bottom: 10px; + border-radius: 4px; + padding: 10px; + max-height: 100px; + font-size: 0.7em; + overflow-y: auto; + background-color: #cfd8dc; +} diff --git a/heymans/tools/__init__.py b/heymans/tools/__init__.py index af75834..334880e 100644 --- a/heymans/tools/__init__.py +++ b/heymans/tools/__init__.py @@ -1,7 +1,6 @@ -from .base_tool import BaseTool -from .search_tool import SearchTool -from .topics_tool import TopicsTool -from .code_execution_tool import CodeExecutionTool -from .google_scholar_tool import GoogleScholarTool -from .attachments_tool import AttachmentsTool -from .download_tool import DownloadTool +from ._base_tool import BaseTool +from ._search_documentation import search_documentation +from ._search_google_scholar import search_google_scholar +from ._read_attachment import read_attachment +from ._execute_code import execute_code +from ._download import download diff --git a/heymans/tools/_base_tool.py b/heymans/tools/_base_tool.py new file mode 100644 index 0000000..f795d44 --- /dev/null +++ b/heymans/tools/_base_tool.py @@ -0,0 +1,42 @@ +import logging +# import re +import json +import functools +from typing import Optional, Tuple +logger = logging.getLogger('heymans') + + +class BaseTool: + """A base class for tools that process an AI reply.""" + + tool_spec = None + + def __init__(self, heymans): + self._heymans = heymans + + @property + def tool_spec(self): + return { + "name": self.__class__.__name__, + "description": self.__doc__, + "parameters": { + "type": "object", + "properties": self.arguments, + "required": self.required_arguments, + } + } + + @property + def name(self): + return self.__class__.__name__ + + def bind(self, args): + if isinstance(args, str): + args = json.loads(args) + return functools.partial(self, **args) + + def __call__(self) -> Tuple[str, Optional[str], bool]: + """Should be implemented in a tool with additional arguments that + match tool specification. + """ + raise NotImplementedError() diff --git a/heymans/tools/download_tool.py b/heymans/tools/_download.py similarity index 84% rename from heymans/tools/download_tool.py rename to heymans/tools/_download.py index c4e21e4..1cdd8d4 100644 --- a/heymans/tools/download_tool.py +++ b/heymans/tools/_download.py @@ -9,17 +9,16 @@ logger = logging.getLogger('heymans') -class DownloadTool(BaseTool): +class download(BaseTool): + """Download files or webpage from the internet and save them as attachments""" - # The JSON pattern should match the regular expression shown in the prompt - # and catch the URL as a group with the name 'url'. - json_pattern = r"\"download_url\":\s*\"(?Phttps?://[^']+)\"" - prompt = '''# Download files - -You have access to the internet. To download a file, use JSON in the format below. The file will be added to your attachments. - -{"download_url": "https://url_to_file"} -''' + arguments = { + "url": { + "type": "string", + "description": "The url of a file of webpage", + } + } + required_arguments = ["url"] def _download(self, url): """Download URL and return a filename, content tuple.""" @@ -56,7 +55,7 @@ def _download(self, url): logger.error(f"Failed to download the file from {url}: {e}") raise - def use(self, message, url): + def __call__(self, url): try: filename, content = self._download(url) except Exception as e: @@ -70,4 +69,5 @@ def use(self, message, url): 'description': description } self._heymans.database.add_attachment(attachment_data) - return f'''I have downloaded {filename} and added it to my attachments.''', True + return f'''I have downloaded {filename} and added it to my attachments.''', \ + None, False diff --git a/heymans/tools/code_execution_tool.py b/heymans/tools/_execute_code.py similarity index 61% rename from heymans/tools/code_execution_tool.py rename to heymans/tools/_execute_code.py index 6c57252..3e56213 100644 --- a/heymans/tools/code_execution_tool.py +++ b/heymans/tools/_execute_code.py @@ -4,26 +4,23 @@ logger = logging.getLogger('heymans') -class CodeExecutionTool(BaseTool): +class execute_code(BaseTool): + """Execute Python and R code""" - json_pattern = r""" -\s*"execute_code"\s*:\s*\{ -\s*"language"\s*:\s*"(?P.+?)" -\s*,\s*"code"\s*:\s*"(?P.+?)" -\s*\} -""" - prompt = '''# Code execution - -You are also a brilliant programmer. To execute Python and R code, use JSON in the format below. You will receive the output in the next message. Example code included elsewhere in your reply will not be executed. Never execute OpenSesame inline scripts. - -{ - "execute_code": { - "language": "python", - "code": "print('your code here')" + arguments = { + 'language': { + 'type': 'string', + 'description': 'The programming language to use', + 'enum': ['r', 'python'] + }, + 'code': { + 'type': 'string', + 'description': 'The code to execute. Use print() to print to the standard output.' + } } -}''' + required_arguments = ['language', 'code'] - def use(self, message, language, code): + def __call__(self, language, code): logger.info(f'executing {language} code: {code}') url = "https://emkc.org/api/v2/piston/execute" language_aliases = {'python': 'python', @@ -45,19 +42,20 @@ def use(self, message, language, code): response = requests.post(url, json=data) if response.status_code == 200: response_data = response.json() - result = response_data.get("run", {}).get("output", "") + result = response_data.get("run", {}).get("output", "").strip() logger.info(f'result: {result}') - result_msg = f'''I executed the following code: + result = f'''I executed the following code: ```{language} {code} ``` -And got the following output: +And received the following output: ``` {result} -```''' - return result_msg, True +``` +''' + return 'Executing code ...', result, True logger.error(f"Error: {response.status_code} with message: {response.content}") - return 'Failed to execute code', True + return 'Failed to execute code', None, True diff --git a/heymans/tools/_read_attachment.py b/heymans/tools/_read_attachment.py new file mode 100644 index 0000000..3679ae3 --- /dev/null +++ b/heymans/tools/_read_attachment.py @@ -0,0 +1,55 @@ +from . import BaseTool +import logging +import json +import base64 +from .. import utils +from ..attachments import file_to_text +logger = logging.getLogger('heymans') + + +class read_attachment(BaseTool): + """Read an attached file""" + arguments = { + "filename": { + "type": "string", + "description": "The attachment file to read", + } + } + required_arguments = ["filename"] + + @property + def tool_spec(self): + arguments = self.arguments.copy() + arguments['filename']['enum'] = [ + attachment['filename'] for attachment in + self._heymans.database.list_attachments().values() + ] + return { + "name": self.__class__.__name__, + "description": self.__doc__, + "parameters": { + "type": "object", + "properties": arguments, + "required": self.required_arguments, + } + } + + def __call__(self, filename): + texts = [] + for attachment_id, attachment in \ + self._heymans.database.list_attachments().items(): + if filename != attachment['filename']: + continue + attachment = self._heymans.database.get_attachment(attachment_id) + content = file_to_text( + attachment['filename'], + base64.b64decode(attachment['content'])) + text = f'File name: {attachment["filename"]}\n\nFile content:\n{content}' + result = f'''One moment please ... + +
+{text} +
''' + return 'I am going to read the attached file now.', result, True + return 'Something went wrong while trying to read the attachment', \ + '', False diff --git a/heymans/tools/_search_documentation.py b/heymans/tools/_search_documentation.py new file mode 100644 index 0000000..dd8a746 --- /dev/null +++ b/heymans/tools/_search_documentation.py @@ -0,0 +1,64 @@ +from . import BaseTool +from .. import config +import logging +from pathlib import Path +from langchain_core.documents import Document +logger = logging.getLogger('heymans') + + +class search_documentation(BaseTool): + """Search the documentation based on topics and search queries""" + + arguments = { + "primary_topic": { + "type": "string", + "description": "The primary topic of the question" + }, + "secondary_topic": { + "type": "string", + "description": "The secondary topic of the question" + }, + "queries": { + "type": "array", + "items": { + "type": "string" + }, + "description": "A list of queries to search the documentation", + } + } + required_arguments = ['primary_topic', 'queries'] + + @property + def tool_spec(self): + topics = list(config.topic_sources.keys()) + arguments = self.arguments.copy() + arguments['primary_topic']['enum'] = topics + arguments['secondary_topic']['enum'] = topics + return { + "name": self.__class__.__name__, + "description": self.__doc__, + "parameters": { + "type": "object", + "properties": arguments, + "required": self.required_arguments, + } + } + + def __call__(self, primary_topic, queries, secondary_topic=None): + if len(self._heymans.documentation) == 0: + logger.info('no topics were added, so skipping search') + else: + self._heymans.documentation.search(queries) + topics = [primary_topic] + if secondary_topic: + topics.append(secondary_topic) + for topic in topics: + if topic not in config.topic_sources: + logger.warning(f'unknown topic: {topic}') + continue + logger.info(f'appending doc for topic: {topic}') + doc = Document( + page_content=Path(config.topic_sources[topic]).read_text()) + doc.metadata['important'] = True + self._heymans.documentation.append(doc) + return 'Searching documentation ...', None, False diff --git a/heymans/tools/_search_google_scholar.py b/heymans/tools/_search_google_scholar.py new file mode 100644 index 0000000..68325bd --- /dev/null +++ b/heymans/tools/_search_google_scholar.py @@ -0,0 +1,39 @@ +from . import BaseTool +import logging +import json +from scholarly import scholarly +logger = logging.getLogger('heymans') + + +class search_google_scholar(BaseTool): + """Search Google Scholar for scientific articles""" + + arguments = { + "queries": { + "type": "array", + "items": { + "type": "string" + }, + "description": "A list of search queries", + } + } + required_arguments = ["queries"] + + def __call__(self, queries): + results = [] + for query in queries: + for i, result in enumerate(scholarly.search_pubs(query)): + logger.info(f'appending doc for google scholar search: {query}') + info = result['bib'] + if 'eprint_url' in result: + info['fulltext_url'] = result['eprint_url'] + results.append(info) + if i >= 3: + break + results = f'''I found {len(results)} articles ... + +
+{json.dumps(results)} +
''' + return 'Searching for articles on Google Scholar ...', \ + results, True diff --git a/heymans/tools/attachments_tool.py b/heymans/tools/attachments_tool.py deleted file mode 100644 index fea6bdc..0000000 --- a/heymans/tools/attachments_tool.py +++ /dev/null @@ -1,50 +0,0 @@ -from . import BaseTool -import logging -import json -import base64 -from .. import utils -from ..attachments import file_to_text -logger = logging.getLogger('heymans') - - -class AttachmentsTool(BaseTool): - - json_pattern = r'"read_attachments"\s*:\s*(?P\[\s*"(?:[^"\\]|\\.)*"(?:\s*,\s*"(?:[^"\\]|\\.)*")*\s*\])' - - @property - def prompt(self): - info = {'read_attachments': []} - for attachment in self._heymans.database.list_attachments().values(): - info['read_attachments'].append(attachment['filename']) - if not info['read_attachments']: - logger.info('no attachments for tool') - return '' - info = json.dumps(info) - return f'''# Attachments - -To read attachments, use JSON in the format below. You will receive the attachments in the next message. - -{info} -''' - - def use(self, message, attachments): - - texts = [] - for attachment_id, attachment in \ - self._heymans.database.list_attachments().items(): - if attachment['filename'] not in attachments: - continue - attachment = self._heymans.database.get_attachment(attachment_id) - content = file_to_text( - attachment['filename'], - base64.b64decode(attachment['content'])) - text = f'File name: {attachment["filename"]}\n\nFile content:\n{content}' - texts.append(text) - if not texts: - return '', False - texts = '\n\n'.join(texts) - return f'''I am going to read the attached file(s) now. - -
-{texts} -
''', True diff --git a/heymans/tools/base_tool.py b/heymans/tools/base_tool.py deleted file mode 100644 index 2d81eb8..0000000 --- a/heymans/tools/base_tool.py +++ /dev/null @@ -1,99 +0,0 @@ -import logging -import re -import json -from typing import Optional, Tuple -logger = logging.getLogger('heymans') - - -class BaseTool: - """A base class for tools that process an AI reply.""" - - # A JSON patter to match tool instructions. Should contain named groups, - # which are passed to use() as named arguments - json_pattern = None - # A prompt section that is automatically included in the system prompt - # when specified. - prompt = None - - def __init__(self, heymans): - self.json_pattern = re.compile(self.json_pattern, - re.VERBOSE | re.DOTALL) - self._heymans = heymans - - def use(self, message: str) -> Tuple[Optional[str], bool]: - """Should be implemented in a tool with additional arguments that - match the names of the fields from the json_pattern. - """ - raise NotImplementedError() - - def run(self, message: str) -> Tuple[str, list, bool]: - """Takes a message and uses the tool if the messages contains relevant - JSON instructions. Returns the updated message, which can be changed by - the tool notably by string the tool JSON instructions, a list of result - strings, and a boolean indicating whether the model should reply to the - result. - - Since there can be multiple instructions for one tool in a single - message, the result is a list, rather than a single string. - """ - spans = [] - results = [] - needs_reply = [] - for match in self.json_pattern.finditer(message): - logger.info(f'executing tool {self}') - args = {self.as_json_value(key) : self.as_json_value(val) - for key, val in match.groupdict().items()} - match_result, match_needs_reply = self.use(message, **args) - if match_result is not None: - results.append(match_result) - needs_reply.append(match_needs_reply) - spans.append((match.start(), match.end())) - # We now loop through all spans that correspond to JSON code blocks. - # We find the first preceding { and succeeding } because the matching - # only occurs within a block, and then we remove this. The goal of this - # is to strip the JSON blocks from the reply. - if spans: - for span in spans[::-1]: - for start in range(span[0], -1, -1): - ch = message[start] - if ch == '{': - break - if start is not span[0] and not ch.isspace() and ch != '"': - start = None - break - for end in range(span[1], len(message) + 1): - ch = message[end] - if ch == '}': - break - if not ch.isspace() and ch != '"': - start = None - break - if start is not None and end is not None: - message = message[:start] + message[end + 1:] - # Remove empty JSON code blocks in case the JSON was embedded in - # blocks - message = re.sub(r'```json\s*```', '', message).strip() - if not message: - message = f'Running `{self.__class__.__name__}` … ''' - return message, results, any(needs_reply) - - def as_json_value(self, s): - orig_s = s - try: - return json.loads(s) - except json.JSONDecodeError: - try: - # Sometimes the JSON is broken by having backslashes in there, - # for example 'inline\_script'. We try to patch this here - return json.loads(s.replace('\\', '')) - except json.JSONDecodeError: - try: - # If this still doesn't work we treat the string as string - # rather than some other structure, in which case newlines - # need to be recoded to literal \n - s = s.replace('\n', r'\n') - return json.loads(f'"{s}"') - except json.JSONDecodeError: - # If this still doesn't work we consider the parsing failed - logger.warning(f'failed to parse JSON: {orig_s}') - return json.loads('"failed to parse JSON"') diff --git a/heymans/tools/google_scholar_tool.py b/heymans/tools/google_scholar_tool.py deleted file mode 100644 index 93c96bf..0000000 --- a/heymans/tools/google_scholar_tool.py +++ /dev/null @@ -1,37 +0,0 @@ -from . import BaseTool -import logging -import json -from scholarly import scholarly -logger = logging.getLogger('heymans') - - -class GoogleScholarTool(BaseTool): - - json_pattern = r'"search_google_scholar"\s*:\s*(?P\[\s*"(?:[^"\\]|\\.)*"(?:\s*,\s*"(?:[^"\\]|\\.)*")*\s*\])' - prompt = '''# Search Google Scholar - -You are also a brilliant researcher. To search for articles on Google Scholar, use JSON in the format below. You will receive the output in the next message. - -{ - "search_google_scholar": [ - "search query 1", - "search query 2" - ] -}''' - - def use(self, message, queries): - results = [] - for query in queries: - for i, result in enumerate(scholarly.search_pubs(query)): - logger.info(f'appending doc for google scholar search: {query}') - info = result['bib'] - if 'eprint_url' in result: - info['fulltext_url'] = result['eprint_url'] - results.append(info) - if i >= 3: - break - return f'''I found {len(results)} articles. I am going to read them now. - -
-{json.dumps(results)} -
''', True diff --git a/heymans/tools/search_tool.py b/heymans/tools/search_tool.py deleted file mode 100644 index 5d2718b..0000000 --- a/heymans/tools/search_tool.py +++ /dev/null @@ -1,18 +0,0 @@ -from . import BaseTool -import logging -logger = logging.getLogger('heymans') - - -class SearchTool(BaseTool): - """Searches through the indexed documentation and inserts matching - fragments into the documentation object. - """ - - json_pattern = r'"search"\s*:\s*(?P\[\s*"(?:[^"\\]|\\.)*"(?:\s*,\s*"(?:[^"\\]|\\.)*")*\s*\])' - - def use(self, message, queries): - if len(self._heymans.documentation) == 0: - logger.info('no topics were added, so skipping search') - else: - self._heymans.documentation.search(queries) - return None, False diff --git a/tests/expensive/expensive_test_utils.py b/tests/expensive/expensive_test_utils.py index 6776249..c451ddb 100644 --- a/tests/expensive/expensive_test_utils.py +++ b/tests/expensive/expensive_test_utils.py @@ -14,6 +14,7 @@ def setUp(self): init_db() self.heymans = Heymans(user_id='pytest', search_first=False) config.max_tokens_per_hour = float('inf') + config.log_replies = True def _test_tool(self): pass diff --git a/tests/expensive/test_tools_attachments.py b/tests/expensive/test_tool_attachment.py similarity index 96% rename from tests/expensive/test_tools_attachments.py rename to tests/expensive/test_tool_attachment.py index ce50bf0..ded5a22 100644 --- a/tests/expensive/test_tools_attachments.py +++ b/tests/expensive/test_tool_attachment.py @@ -14,6 +14,7 @@ def _test_tool(self): assert len(self.heymans.database.list_attachments()) == 1 query = 'Which artist name does the attachment contain?' for reply, metadata in self.heymans.send_user_message(query): + print(reply) if 'Rick Ross' in reply: break else: diff --git a/tests/expensive/test_tools_download.py b/tests/expensive/test_tool_download.py similarity index 100% rename from tests/expensive/test_tools_download.py rename to tests/expensive/test_tool_download.py diff --git a/tests/expensive/test_tool_execute_code.py b/tests/expensive/test_tool_execute_code.py new file mode 100644 index 0000000..5cc4153 --- /dev/null +++ b/tests/expensive/test_tool_execute_code.py @@ -0,0 +1,10 @@ +import base64 +from .expensive_test_utils import BaseExpensiveTest + + +class TestExecuteCode(BaseExpensiveTest): + + def _test_tool(self): + query = 'Can you calculate the square root of 7 using Python code?' + for reply, metadata in self.heymans.send_user_message(query): + print(reply) diff --git a/tests/expensive/test_tool_search_documentation.py b/tests/expensive/test_tool_search_documentation.py new file mode 100644 index 0000000..3840057 --- /dev/null +++ b/tests/expensive/test_tool_search_documentation.py @@ -0,0 +1,16 @@ +import base64 +from heymans.heymans import Heymans +from .expensive_test_utils import BaseExpensiveTest + + +class TestToolSearchDocumentation(BaseExpensiveTest): + + def setUp(self): + super().setUp() + self.heymans = Heymans(user_id='pytest', search_first=True) + + def _test_tool(self): + query = 'What is the first header line of the OpenSesame topic documentation?' + for reply, metadata in self.heymans.send_user_message(query): + print(reply) + assert 'important' in reply.lower() diff --git a/tests/expensive/test_tools_google_scholar.py b/tests/expensive/test_tool_search_google_scholar.py similarity index 71% rename from tests/expensive/test_tools_google_scholar.py rename to tests/expensive/test_tool_search_google_scholar.py index dbf62eb..b65d862 100644 --- a/tests/expensive/test_tools_google_scholar.py +++ b/tests/expensive/test_tool_search_google_scholar.py @@ -6,7 +6,5 @@ class TestToolsGoogleScholar(BaseExpensiveTest): def _test_tool(self): query = 'Can you search Google Scholar for articles about pupillometry in psychology? What is the title of the review article by Mathôt?' for reply, metadata in self.heymans.send_user_message(query): - if 'Pupillometry: Psychology, physiology, and function' in reply: - break - else: - assert False + print(reply) + assert 'Pupillometry: Psychology, physiology, and function' in reply