diff --git a/heymans/config.py b/heymans/config.py index 762d700..ee83af7 100644 --- a/heymans/config.py +++ b/heymans/config.py @@ -228,7 +228,7 @@ def validate_user(username, password): # SUBSCRIPTIONS # # Enable this to activate the Stripe-based subscription functionality. -subscription_required = True +subscription_required = False # This is the duration of the subscription in days. This should be set to a bit # longer than a month to provide a grace period in case of payment issues. subscription_length = 40 diff --git a/heymans/documentation.py b/heymans/documentation.py index d5c3d12..e647690 100644 --- a/heymans/documentation.py +++ b/heymans/documentation.py @@ -64,11 +64,13 @@ def strip_irrelevant(self, question): replies = self._heymans.condense_model.predict_multiple(prompts) for reply, doc in zip(replies, optional): doc_desc = f'{doc.metadata["url"]} ({doc.metadata["seq_num"]})' - if not reply.lower().startswith('no'): + if reply.lower().startswith('no'): important.append(doc) - logger.info(f'keeping {doc_desc}') + logger.info(f'keeping {doc_desc} (reply: {reply})') + elif reply.lower().startswith('yes'): + logger.info(f'stripping {doc_desc} (reply: {reply})') else: - logger.info(f'stripping {doc_desc}') + logger.warning(f'invalid reply: {reply}') self._documents = important def clear(self): diff --git a/heymans/heymans.py b/heymans/heymans.py index b7af154..2ecb844 100644 --- a/heymans/heymans.py +++ b/heymans/heymans.py @@ -161,7 +161,8 @@ def _answer(self, state: str = 'answer') -> GeneratorType: if tool_result: metadata = self.messages.append('tool', json.dumps(tool_result)) - yield tool_result['content'], metadata + if tool_result['content']: + yield tool_result['content'], metadata # Otherwise the reply is a regular AI message else: metadata = self.messages.append('assistant', reply) diff --git a/heymans/messages.py b/heymans/messages.py index 8813fd7..6303cb3 100644 --- a/heymans/messages.py +++ b/heymans/messages.py @@ -109,6 +109,20 @@ def 
prompt(self, system_prompt=None): else: raise ValueError(f'Invalid role: {role}') return model_prompt + + def visible_messages(self): + """Yields role, message, metadata while ignoring empty messages and + converting tool messages into assistant messages with tool result as + content. This is mainly for display in the web interface. + """ + for role, message, metadata in self: + if role == 'tool': + role = 'assistant' + tool_results = json.loads(message) + message = tool_results['content'] + if not message.strip(): + continue + yield role, message, metadata def welcome_message(self): if self._heymans.search_first: @@ -161,8 +175,6 @@ def _system_prompt(self): # If available, documentation is also included in the prompt if len(self._heymans.documentation): system_prompt.append(self._heymans.documentation.prompt()) - system_prompt.append( - attachments.attachments_prompt(self._heymans.database)) # And finally, if the message history has been condensed, this is also # included. if self._condensed_text: diff --git a/heymans/model/_anthropic_model.py b/heymans/model/_anthropic_model.py index e0cf8b1..5a1c4bc 100644 --- a/heymans/model/_anthropic_model.py +++ b/heymans/model/_anthropic_model.py @@ -63,10 +63,10 @@ def predict(self, messages): next_message = messages[i + 1] if next_message['role'] == 'user': logger.info('merging tool and user message') - message['content'].append([{ + message['content'].append({ "type": "text", "text": next_message['content'] - }]) + }) break else: break diff --git a/heymans/model/_base_model.py b/heymans/model/_base_model.py index ee69020..b905cdb 100644 --- a/heymans/model/_base_model.py +++ b/heymans/model/_base_model.py @@ -31,6 +31,8 @@ def get_response(self, response) -> [str, callable]: return response.content def tools(self): + if self._tools is None: + return [] return [{"type": "function", "function": t.tool_spec} for t in self._tools if t.tool_spec] diff --git a/heymans/model/_mistral_model.py b/heymans/model/_mistral_model.py index 
af83f00..2f48b6b 100644 --- a/heymans/model/_mistral_model.py +++ b/heymans/model/_mistral_model.py @@ -1,14 +1,15 @@ from .. import config, utils from . import BaseModel from ._openai_model import OpenAIModel +import logging from langchain.schema import SystemMessage, AIMessage, HumanMessage, \ FunctionMessage +logger = logging.getLogger('heymans') class MistralModel(OpenAIModel): supports_not_done_yet = False - supports_tool_feedback = False def __init__(self, heymans, model, **kwargs): from mistralai.async_client import MistralAsyncClient @@ -29,12 +30,32 @@ def predict(self, messages): merge_consecutive=True) messages = [self.convert_message(message) for message in messages] messages = self._prepare_tool_messages(messages) + # Mistral requires an assistant message after a tool message + while True: + for i, message in enumerate(messages[:-1]): + next_message = messages[i + 1] + if message['role'] == 'tool' and \ + next_message['role'] == 'user': + break + else: + break + logger.info('adding assistant message between tool and user') + messages.insert(i + 1, {'role': 'assistant', + 'content': 'Tool was executed.'}) return BaseModel.predict(self, messages) + + def _mistral_tool_args(self, messages): + # Mistral tends to get stuck in a loop where the same tool is called + # over and over again. To fix this, we temporarily disallow tools when + # the last message was a tool. 
+ if messages[-1]['role'] == 'tool': + return {} + return self._tool_args() def invoke(self, messages): return self._client.chat(model=self._model, messages=messages, - **self._tool_args()) + **self._mistral_tool_args(messages)) def async_invoke(self, messages): return self._async_client.chat(model=self._model, messages=messages, - **self._tool_args()) + **self._mistral_tool_args(messages)) diff --git a/heymans/routes/app.py b/heymans/routes/app.py index 964ddb6..4549d3a 100644 --- a/heymans/routes/app.py +++ b/heymans/routes/app.py @@ -35,7 +35,7 @@ def chat_page(): html_content = '' previous_timestamp = None previous_answer_model = None - for role, message, metadata in heymans.messages: + for role, message, metadata in heymans.messages.visible_messages(): message_id = metadata.get('message_id', 0) delete_button = f'' if role in ('assistant', 'tool'): diff --git a/heymans/tools/_base_tool.py b/heymans/tools/_base_tool.py index f052524..5e7b363 100644 --- a/heymans/tools/_base_tool.py +++ b/heymans/tools/_base_tool.py @@ -42,7 +42,12 @@ def bind(self, args: str) -> callable: called again to provide feedback based on the tool result. 
""" def bound_tool_function(): - message, result, needs_feedback = self(**json.loads(args)) + try: + message, result, needs_feedback = self(**json.loads(args)) + except Exception as e: + message = 'Failed to run tool' + result = f'The following error occurred while trying to run tool:\n\n{e}' + needs_feedback = True result = {'name': self.name, 'args': args, 'content': result} diff --git a/heymans/tools/_download.py b/heymans/tools/_download.py index 1cdd8d4..648963b 100644 --- a/heymans/tools/_download.py +++ b/heymans/tools/_download.py @@ -56,11 +56,7 @@ def _download(self, url): raise def __call__(self, url): - try: - filename, content = self._download(url) - except Exception as e: - return f'I failed to download the file for the following reason: {e}', \ - True + filename, content = self._download(url) description = attachments.describe_file(filename, content, self._heymans.condense_model) attachment_data = { @@ -70,4 +66,4 @@ def __call__(self, url): } self._heymans.database.add_attachment(attachment_data) return f'''I have downloaded {filename} and added it to my attachments.''', \ - None, False + '', False diff --git a/heymans/tools/_execute_code.py b/heymans/tools/_execute_code.py index 3e56213..c3c732c 100644 --- a/heymans/tools/_execute_code.py +++ b/heymans/tools/_execute_code.py @@ -58,4 +58,4 @@ def __call__(self, language, code): ''' return 'Executing code ...', result, True logger.error(f"Error: {response.status_code} with message: {response.content}") - return 'Failed to execute code', None, True + return 'Failed to execute code', '', True diff --git a/heymans/tools/_search_documentation.py b/heymans/tools/_search_documentation.py index dd8a746..42ec43f 100644 --- a/heymans/tools/_search_documentation.py +++ b/heymans/tools/_search_documentation.py @@ -61,4 +61,4 @@ def __call__(self, primary_topic, queries, secondary_topic=None): page_content=Path(config.topic_sources[topic]).read_text()) doc.metadata['important'] = True 
self._heymans.documentation.append(doc) - return 'Searching documentation ...', None, False + return 'Searching documentation ...', '', False diff --git a/tests/expensive/test_tool_attachment.py b/tests/expensive/test_tool_attachment.py index ded5a22..8d7e52a 100644 --- a/tests/expensive/test_tool_attachment.py +++ b/tests/expensive/test_tool_attachment.py @@ -19,3 +19,9 @@ def _test_tool(self): break else: assert False + query = 'Can you download the following readme: https://raw.githubusercontent.com/open-cogsci/OpenSesame/milgram/readme.md' + for reply, metadata in self.heymans.send_user_message(query): + print(reply) + query = 'Can you read and summarize the readme for me?' + for reply, metadata in self.heymans.send_user_message(query): + print(reply)