Skip to content

Commit

Permalink
Various updated to improve Mistral tool use
Browse files Browse the repository at this point in the history
  • Loading branch information
smathot committed Apr 16, 2024
1 parent f971634 commit 50c71ad
Show file tree
Hide file tree
Showing 13 changed files with 67 additions and 22 deletions.
2 changes: 1 addition & 1 deletion heymans/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def validate_user(username, password):
# SUBSCRIPTIONS
#
# Enable this to activate the Stripe-based subscription functionality.
subscription_required = True
subscription_required = False
# This is the duration of the subscription in days. This should be set to a bit
# longer than a month to provide a grace period in case of payment issues.
subscription_length = 40
Expand Down
8 changes: 5 additions & 3 deletions heymans/documentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,13 @@ def strip_irrelevant(self, question):
replies = self._heymans.condense_model.predict_multiple(prompts)
for reply, doc in zip(replies, optional):
doc_desc = f'{doc.metadata["url"]} ({doc.metadata["seq_num"]})'
if not reply.lower().startswith('no'):
if reply.lower().startswith('no'):
important.append(doc)
logger.info(f'keeping {doc_desc}')
logger.info(f'keeping {doc_desc} (reply: {reply})')
elif reply.lower().startswith('yes'):
logger.info(f'stripping {doc_desc} (reply: {reply})')
else:
logger.info(f'stripping {doc_desc}')
logger.warning(f'invalid reply: {reply}')
self._documents = important

def clear(self):
Expand Down
3 changes: 2 additions & 1 deletion heymans/heymans.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ def _answer(self, state: str = 'answer') -> GeneratorType:
if tool_result:
metadata = self.messages.append('tool',
json.dumps(tool_result))
yield tool_result['content'], metadata
if tool_result['content']:
yield tool_result['content'], metadata
# Otherwise the reply is a regular AI message
else:
metadata = self.messages.append('assistant', reply)
Expand Down
16 changes: 14 additions & 2 deletions heymans/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,20 @@ def prompt(self, system_prompt=None):
else:
raise ValueError(f'Invalid role: {role}')
return model_prompt

def visible_messages(self):
"""Yields role, message, metadata while ignoring messages and
converting tool messages into user messages with tool result as
content. This is mainly for display in the web interface.
"""
for role, message, metadata in self:
if role == 'tool':
role = 'assistant'
tool_results = json.loads(message)
message = tool_results['content']
if not message.strip():
continue
yield role, message, metadata

def welcome_message(self):
if self._heymans.search_first:
Expand Down Expand Up @@ -161,8 +175,6 @@ def _system_prompt(self):
# If available, documentation is also included in the prompt
if len(self._heymans.documentation):
system_prompt.append(self._heymans.documentation.prompt())
system_prompt.append(
attachments.attachments_prompt(self._heymans.database))
# And finally, if the message history has been condensed, this is also
# included.
if self._condensed_text:
Expand Down
4 changes: 2 additions & 2 deletions heymans/model/_anthropic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ def predict(self, messages):
next_message = messages[i + 1]
if next_message['role'] == 'user':
logger.info('merging tool and user message')
message['content'].append([{
message['content'].append({
"type": "text",
"text": next_message['content']
}])
})
break
else:
break
Expand Down
2 changes: 2 additions & 0 deletions heymans/model/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def get_response(self, response) -> [str, callable]:
return response.content

def tools(self):
if self._tools is None:
return []
return [{"type": "function", "function": t.tool_spec}
for t in self._tools if t.tool_spec]

Expand Down
27 changes: 24 additions & 3 deletions heymans/model/_mistral_model.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from .. import config, utils
from . import BaseModel
from ._openai_model import OpenAIModel
import logging
from langchain.schema import SystemMessage, AIMessage, HumanMessage, \
FunctionMessage
logger = logging.getLogger('heymans')


class MistralModel(OpenAIModel):

supports_not_done_yet = False
supports_tool_feedback = False

def __init__(self, heymans, model, **kwargs):
from mistralai.async_client import MistralAsyncClient
Expand All @@ -29,12 +30,32 @@ def predict(self, messages):
merge_consecutive=True)
messages = [self.convert_message(message) for message in messages]
messages = self._prepare_tool_messages(messages)
# Mistral requires an assistant message after a tool message
while True:
for i, message in enumerate(messages[:-1]):
next_message = messages[i + 1]
if message['role'] == 'tool' and \
next_message['role'] == 'user':
break
else:
break
logger.info('adding assistant message between tool and user')
messages.insert(i + 1, {'role': 'assistant',
'content': 'Tool was executed.'})
return BaseModel.predict(self, messages)

def _mistral_tool_args(self, messages):
# Mistral tends to get stuck in a loop where the same tool is called
# over and over again. To fix this, we temporarily disallow tools when
# the last message was a tool.
if messages[-1]['role'] == 'tool':
return {}
return self._tool_args()

def invoke(self, messages):
return self._client.chat(model=self._model, messages=messages,
**self._tool_args())
**self._mistral_tool_args(messages))

def async_invoke(self, messages):
return self._async_client.chat(model=self._model, messages=messages,
**self._tool_args())
**self._mistral_tool_args(messages))
2 changes: 1 addition & 1 deletion heymans/routes/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def chat_page():
html_content = ''
previous_timestamp = None
previous_answer_model = None
for role, message, metadata in heymans.messages:
for role, message, metadata in heymans.messages.visible_messages():
message_id = metadata.get('message_id', 0)
delete_button = f'<button class="message-delete" onclick="deleteMessage(\'{message_id}\')"><i class="fas fa-trash"></i></button>'
if role in ('assistant', 'tool'):
Expand Down
7 changes: 6 additions & 1 deletion heymans/tools/_base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ def bind(self, args: str) -> callable:
called again to provide feedback based on the tool result.
"""
def bound_tool_function():
message, result, needs_feedback = self(**json.loads(args))
try:
message, result, needs_feedback = self(**json.loads(args))
except Exception as e:
message = 'Failed to run tool'
result = f'The following error occurred while trying to run tool:\n\n{e}'
needs_feedback = True
result = {'name': self.name,
'args': args,
'content': result}
Expand Down
8 changes: 2 additions & 6 deletions heymans/tools/_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,7 @@ def _download(self, url):
raise

def __call__(self, url):
try:
filename, content = self._download(url)
except Exception as e:
return f'I failed to download the file for the following reason: {e}', \
True
filename, content = self._download(url)
description = attachments.describe_file(filename, content,
self._heymans.condense_model)
attachment_data = {
Expand All @@ -70,4 +66,4 @@ def __call__(self, url):
}
self._heymans.database.add_attachment(attachment_data)
return f'''I have downloaded {filename} and added it to my attachments.''', \
None, False
'', False
2 changes: 1 addition & 1 deletion heymans/tools/_execute_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,4 @@ def __call__(self, language, code):
'''
return 'Executing code ...', result, True
logger.error(f"Error: {response.status_code} with message: {response.content}")
return 'Failed to execute code', None, True
return 'Failed to execute code', '', True
2 changes: 1 addition & 1 deletion heymans/tools/_search_documentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,4 @@ def __call__(self, primary_topic, queries, secondary_topic=None):
page_content=Path(config.topic_sources[topic]).read_text())
doc.metadata['important'] = True
self._heymans.documentation.append(doc)
return 'Searching documentation ...', None, False
return 'Searching documentation ...', '', False
6 changes: 6 additions & 0 deletions tests/expensive/test_tool_attachment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,9 @@ def _test_tool(self):
break
else:
assert False
query = 'Can you download the following readme: https://raw.githubusercontent.com/open-cogsci/OpenSesame/milgram/readme.md'
for reply, metadata in self.heymans.send_user_message(query):
print(reply)
query = 'Can you read and summarize the readme for me?'
for reply, metadata in self.heymans.send_user_message(query):
print(reply)

0 comments on commit 50c71ad

Please sign in to comment.