Commit: Cheap and expensive unittests work now

smathot committed Apr 16, 2024
1 parent 7360b96 commit f971634

Showing 11 changed files with 121 additions and 154 deletions.
14 changes: 10 additions & 4 deletions heymans/heymans.py
@@ -1,5 +1,6 @@
 import logging
 import jinja2
+import json
 from types import GeneratorType
 from typing import Tuple, Optional
 from . import config, library
@@ -147,15 +148,20 @@ def _answer(self, state: str = 'answer') -> GeneratorType:
         logger.info(f'[{state} state] reply: {reply}')
         # If the reply is a callable, then it's a tool that we need to run
         if callable(reply):
-            tool_message, result, needs_feedback = reply()
+            tool_message, tool_result, needs_feedback = reply()
+            if needs_feedback:
+                logger.info(f'[{state} state] tools need feedback')
+                if not self.answer_model.supports_tool_feedback:
+                    logger.info(
+                        f'[{state} state] model does not support feedback')
+                    needs_feedback = False
             metadata = self.messages.append('assistant', tool_message)
             yield tool_message, metadata
             # If the tool has a result, yield and remember it
-            if result:
-                metadata = self.messages.append('tool', result)
-                yield result, metadata
+            if tool_result:
+                metadata = self.messages.append('tool',
+                                                json.dumps(tool_result))
+                yield tool_result['content'], metadata
         # Otherwise the reply is a regular AI message
         else:
             metadata = self.messages.append('assistant', reply)
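
The hunk above changes how tool replies are handled: the bound tool now returns a result dict rather than a plain string, the dict is stored in the message history as JSON, and only its content is yielded to the caller. A minimal sketch of that flow, assuming a plain list as the message store (handle_tool_reply and its arguments are illustrative, not part of the commit):

    import json

    def handle_tool_reply(reply, messages):
        # reply() is a bound tool function returning (message, result, feedback)
        tool_message, tool_result, needs_feedback = reply()
        messages.append(('assistant', tool_message))
        if tool_result:
            # The full result dict ('name', 'args', 'content') is stored as JSON
            messages.append(('tool', json.dumps(tool_result)))
            # ...but only the human-readable content is surfaced
            return tool_result['content'], needs_feedback
        return tool_message, needs_feedback
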
6 changes: 1 addition & 5 deletions heymans/messages.py
@@ -146,15 +146,11 @@ def _system_prompt(self):
         """The system prompt that is used for question answering consists of
         several fragments.
         """
-        # There is always and identity, information about the current time,
-        # and a list of attached files
+        # There is always an identity and a list of attached files
         if self._heymans.search_first:
             system_prompt = [prompt.SYSTEM_PROMPT_IDENTITY_WITH_SEARCH]
         else:
             system_prompt = [prompt.SYSTEM_PROMPT_IDENTITY_WITHOUT_SEARCH]
-        system_prompt.append(
-            prompt.render(prompt.SYSTEM_PROMPT_DATETIME,
-                          current_datetime=utils.current_datetime()))
         system_prompt.append(
             attachments.attachments_prompt(self._heymans.database))
         # For models that support this, there is also an instruction indicating
17 changes: 8 additions & 9 deletions heymans/model/__init__.py
@@ -1,30 +1,29 @@
 from ._base_model import BaseModel
-from ._openai_model import OpenAIModel
-from ._mistral_model import MistralModel
-from ._anthropic_model import AnthropicModel
 
 
-class DummyModel(BaseModel):
-    def predict(self, messages):
-        return 'dummy reply'
-
-
 def model(heymans, model, **kwargs):
-
     """A factory function that returns a Model instance."""
     if model == 'gpt-4':
+        from ._openai_model import OpenAIModel
         return OpenAIModel(heymans, 'gpt-4-1106-preview', **kwargs)
     if model == 'gpt-3.5':
+        from ._openai_model import OpenAIModel
         return OpenAIModel(heymans, 'gpt-3.5-turbo-1106', **kwargs)
     if model == 'claude-2.1':
+        from ._anthropic_model import AnthropicModel
        return AnthropicModel(heymans, 'claude-2.1', **kwargs)
     if model == 'claude-3-opus':
+        from ._anthropic_model import AnthropicModel
         return AnthropicModel(heymans, 'claude-3-opus-20240229', **kwargs)
     if model == 'claude-3-sonnet':
+        from ._anthropic_model import AnthropicModel
         return AnthropicModel(heymans, 'claude-3-sonnet-20240229', **kwargs)
     if model.startswith('mistral-'):
+        from ._mistral_model import MistralModel
         if not model.endswith('-latest'):
             model += '-latest'
         return MistralModel(heymans, model, **kwargs)
     if model == 'dummy':
+        from ._dummy_model import DummyModel
         return DummyModel(heymans, **kwargs)
     raise ValueError(f'Unknown model: {model}')
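
Usage of the factory is unchanged; the imports moved into the branches presumably so that only the selected client library is loaded. A quick sketch, assuming a heymans instance and the relevant API keys (the 'gpt-5' call is only there to show the error path):

    from heymans.model import model

    m = model(heymans, 'gpt-4')           # OpenAIModel with 'gpt-4-1106-preview'
    m = model(heymans, 'mistral-medium')  # becomes 'mistral-medium-latest'
    m = model(heymans, 'dummy')           # DummyModel, useful for cheap unittests
    m = model(heymans, 'gpt-5')           # raises ValueError: Unknown model: gpt-5
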
38 changes: 22 additions & 16 deletions heymans/model/_anthropic_model.py
@@ -1,11 +1,14 @@
 from . import BaseModel
 from .. import config, utils
 import logging
+import json
 logger = logging.getLogger('heymans')
 
 
 class AnthropicModel(BaseModel):
 
+    supports_not_done_yet = False
+
     def __init__(self, heymans, model, **kwargs):
         from anthropic import Anthropic, AsyncAnthropic
         super().__init__(heymans, **kwargs)
@@ -30,38 +33,41 @@ def predict(self, messages):
         logger.info('entering message postprocessing loop')
         for i, message in enumerate(messages):
             if message['role'] == 'tool':
+                if i == 0:
+                    raise ValueError(
+                        'The first message cannot be a tool message')
                 logger.info('converting tool message to user message')
+                tool_info = json.loads(message['content'])
                 message['role'] = 'user'
                 message['content'] = [{
                     'type': 'tool_result',
                     'tool_use_id': str(self._tool_use_id),
                     'content': [{
                         'type': 'text',
-                        'text': message['content']
+                        'text': tool_info['content']
                     }]
                 }]
-                if i > 0:
-                    # The previous message needs to have a tool-use block
-                    prev_message = messages[i - 1]
-                    prev_message['content'] = [
-                        {'type': 'text',
-                         'text': prev_message['content']},
-                        {'type': 'tool_use',
-                         'id': str(self._tool_use_id),
-                         'input': {'args': 'input args'},
-                         'name': 'tool_function'
-                        }
-                    ]
+                # The previous message needs to have a tool-use block
+                prev_message = messages[i - 1]
+                prev_message['content'] = [
+                    {'type': 'text',
+                     'text': prev_message['content']},
+                    {'type': 'tool_use',
+                     'id': str(self._tool_use_id),
+                     'input': {'args': tool_info['args']},
+                     'name': tool_info['name']
+                    }
+                ]
                 self._tool_use_id += 1
                 if len(messages) > i + 1:
-                    logger.info('merging tool and user message')
                     next_message = messages[i + 1]
                     if next_message['role'] == 'user':
+                        logger.info('merging tool and user message')
                         message['content'].append([{
                             "type": "text",
                             "text": next_message['content']
                         }])
-                    break
                 break
             else:
                 break
         logger.info('dropping duplicate user message')
@@ -74,7 +80,7 @@ def get_response(self, response):
             if block.type == 'tool_use':
                 for tool in self._tools:
                     if tool.name == block.name:
-                        return tool.bind(block.input)
+                        return tool.bind(json.dumps(block.input))
                 return self.invalid_tool
             if block.type == 'text':
                 text.append(block.text)
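
For reference, this is roughly the message pair that the postprocessing loop above produces for a tool exchange in Anthropic's content-block format (the tool name and values are made up for illustration):

    # The preceding assistant message is retrofitted with a tool_use block
    prev_message = {
        'role': 'assistant',
        'content': [
            {'type': 'text', 'text': 'Let me look that up.'},
            {'type': 'tool_use',
             'id': '1',                                    # str(self._tool_use_id)
             'input': {'args': '{"query": "staircase"}'},  # tool_info['args']
             'name': 'search_documentation'}               # tool_info['name']
        ]
    }
    # The tool message itself becomes a user message with a tool_result block
    message = {
        'role': 'user',
        'content': [
            {'type': 'tool_result',
             'tool_use_id': '1',
             'content': [{'type': 'text',
                          'text': '...search results...'}]}  # tool_info['content']
        ]
    }
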
11 changes: 8 additions & 3 deletions heymans/model/_base_model.py
@@ -7,8 +7,13 @@
 
 
 class BaseModel:
+    """Base implementation for LLM chat models."""
 
-    supports_not_done_yet = False
+    # Indicates whether the model is able to provide feedback on its own output
+    supports_not_done_yet = True
+    # Indicates whether the model is able to provide feedback on tool results
+    supports_tool_feedback = True
+    # Approximation to keep track of token costs
     characters_per_token = 4
 
     def __init__(self, heymans, tools=None, tool_choice='auto'):
@@ -20,7 +25,7 @@ def __init__(self, heymans, tools=None, tool_choice='auto'):
         self.completion_tokens_consumed = 0
 
     def invalid_tool(self) -> str:
-        return 'Invalid tool'
+        return 'Invalid tool', None, False
 
     def get_response(self, response) -> [str, callable]:
         return response.content
@@ -37,7 +42,7 @@ def async_invoke(self, messages):
 
     def messages_length(self, messages) -> int:
         if isinstance(messages, str):
-            return lebase_format_toolsn(messages)
+            return len(messages)
         return sum([len(m.content if hasattr(m, 'content') else m['content'])
                     for m in messages])
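
A small sketch of how the corrected messages_length combines with the characters_per_token approximation to estimate token costs (illustrative only; assumes a heymans instance):

    model = BaseModel(heymans)
    # For a plain string, messages_length is now simply len()
    n_chars = model.messages_length('x' * 400)       # 400
    n_tokens = n_chars / model.characters_per_token  # ~100 tokens
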
17 changes: 17 additions & 0 deletions heymans/model/_dummy_model.py
@@ -0,0 +1,17 @@
+from ._base_model import BaseModel
+
+
+class DummyModel(BaseModel):
+
+    def get_response(self, response):
+        return response
+
+    def invoke(self, messages):
+        return 'dummy reply'
+
+    async def _async_task(self):
+        return 'dummy reply'
+
+    def async_invoke(self, messages):
+        import asyncio
+        return asyncio.create_task(self._async_task())
12 changes: 5 additions & 7 deletions heymans/model/_mistral_model.py
@@ -1,12 +1,14 @@
 from .. import config, utils
-from . import OpenAIModel, BaseModel
+from . import BaseModel
+from ._openai_model import OpenAIModel
 from langchain.schema import SystemMessage, AIMessage, HumanMessage, \
     FunctionMessage
 
 
 class MistralModel(OpenAIModel):
 
     supports_not_done_yet = False
+    supports_tool_feedback = False
 
     def __init__(self, heymans, model, **kwargs):
         from mistralai.async_client import MistralAsyncClient
@@ -18,19 +20,15 @@ def __init__(self, heymans, model, **kwargs):
         self._client = MistralClient(api_key=config.mistral_api_key)
         self._async_client = MistralAsyncClient(api_key=config.mistral_api_key)
 
-    def convert_message(self, message):
-        from mistralai.models.chat_completion import ChatMessage
-        message = super().convert_message(message)
-        return ChatMessage(**message)
-
     def predict(self, messages):
         if isinstance(messages, str):
             messages = [self.convert_message(messages)]
         else:
             messages = utils.prepare_messages(messages, allow_ai_first=False,
                                               allow_ai_last=False,
                                               merge_consecutive=True)
-        messages = [self.convert_message(message) for message in messages]
+            messages = [self.convert_message(message) for message in messages]
+        messages = self._prepare_tool_messages(messages)
         return BaseModel.predict(self, messages)
 
     def invoke(self, messages):
49 changes: 29 additions & 20 deletions heymans/model/_openai_model.py
@@ -1,11 +1,10 @@
+import json
 from .. import config
 from . import BaseModel
 
 
 class OpenAIModel(BaseModel):
 
-    supports_not_done_yet = True
-
     def __init__(self, heymans, model, **kwargs):
         from openai import Client, AsyncClient
         super().__init__(heymans, **kwargs)
@@ -23,33 +22,43 @@ def predict(self, messages):
             messages = [self.convert_message(messages)]
         else:
             messages = [self.convert_message(message) for message in messages]
-        # OpenAI requires the tool message to be linked to the previous AI
-        # message with a tool_call_id. The actual content doesn't appear to
-        # matter, so here we dummy-link the messages
-        for i, message in enumerate(messages):
-            if i == 0 or message['role'] != 'tool':
-                continue
-            tool_call_id = f'call_{i}'
-            prev_message = messages[i - 1]
-            prev_message['tool_calls'] = [
-                {
-                    'id': tool_call_id,
-                    'type': 'function',
-                    'function': {
-                        'name': 'dummy',
-                        'arguments': ''
-                    }
-                }]
-            message['tool_call_id'] = tool_call_id
+        messages = self._prepare_tool_messages(messages)
         return super().predict(messages)
 
+    def _prepare_tool_messages(self, messages):
+        # OpenAI requires the tool message to be linked to the previous AI
+        # message with a tool_call_id. The actual content doesn't appear to
+        # matter, so here we dummy-link the messages
+        for i, message in enumerate(messages):
+            if i == 0 or message['role'] != 'tool':
+                continue
+            tool_info = json.loads(message['content'])
+            tool_call_id = f'call_{i}'
+            prev_message = messages[i - 1]
+            # an assistant message should not have both content and tool calls
+            prev_message['content'] = ''
+            prev_message['tool_calls'] = [
+                {
+                    'id': tool_call_id,
+                    'type': 'function',
+                    'function': {
+                        'name': tool_info['name'],
+                        'arguments': tool_info['args']
+                    }
+                }]
+            message['tool_call_id'] = tool_call_id
+            message['name'] = tool_info['name']
+            message['content'] = tool_info['content']
+        return messages
 
     def get_response(self, response):
         tool_calls = response.choices[0].message.tool_calls
         if tool_calls:
             function = tool_calls[0].function
             for tool in self._tools:
                 if tool.name == function.name:
                     return tool.bind(function.arguments)
+            logger.warning(f'invalid tool called: {function}')
             return self.invalid_tool
         return response.choices[0].message.content
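
Roughly, _prepare_tool_messages rewrites a tool exchange as follows (values are illustrative; note how the serialized result dict from the tool message is unpacked into OpenAI's tool-call fields):

    import json

    messages = [
        {'role': 'assistant', 'content': 'Searching the docs.'},
        {'role': 'tool', 'content': json.dumps({
            'name': 'search',
            'args': '{"query": "..."}',
            'content': 'results'})},
    ]
    # After _prepare_tool_messages(messages):
    # messages[0]['content']      == ''       (no content next to tool_calls)
    # messages[0]['tool_calls']   == [{'id': 'call_1', 'type': 'function',
    #                                  'function': {'name': 'search',
    #                                               'arguments': '{"query": "..."}'}}]
    # messages[1]['tool_call_id'] == 'call_1'
    # messages[1]['name']         == 'search'
    # messages[1]['content']      == 'results'
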
8 changes: 1 addition & 7 deletions heymans/prompt.py
@@ -6,16 +6,10 @@
 # The system prompt used during question answering is composed of the fragments
 # below
 SYSTEM_PROMPT_IDENTITY_WITH_SEARCH = '''You are Sigmund, a brilliant AI assistant for users of OpenSesame, a program for building psychology and neuroscience experiments. You sometimes use emojis.'''
-SYSTEM_PROMPT_IDENTITY_WITHOUT_SEARCH = '''You are Sigmund, a brilliant AI assistant. You sometimes use emojis.
-If and only if the user asks a question related to OpenSesame, a program for building psychology and neuroscience experiments, don't answer the question. Instead, suggest that the user gives you access to the OpenSesame documentation by enabling OpenSesame expert mode through the menu of this chat web application.
-'''
+SYSTEM_PROMPT_IDENTITY_WITHOUT_SEARCH = '''You are Sigmund, a brilliant AI assistant. You sometimes use emojis.'''
 # Sent by AI to indicate that a message requires further replies or actions
 NOT_DONE_YET_MARKER = '<NOT_DONE_YET>'
 SYSTEM_PROMPT_NOT_DONE_YET = f'''When you intend to perform an action ("please wait", "I will now"), such as searching or code execution, end your reply with {NOT_DONE_YET_MARKER}.'''
-SYSTEM_PROMPT_DATETIME = '''# Current date and time
-{{ current_datetime }}'''
 SYSTEM_PROMPT_ATTACHMENTS = '''# Attachments
 You have access to the following attached files:
27 changes: 20 additions & 7 deletions heymans/tools/_base_tool.py
@@ -9,13 +9,14 @@
 class BaseTool:
     """A base class for tools that process an AI reply."""
 
+    tool_spec = None
+
     def __init__(self, heymans):
         self._heymans = heymans
 
     @property
     def tool_spec(self):
         """The tool spec should correspond to the OpenAI specification for
         function tools.
         """
         return {
             "name": self.__class__.__name__,
             "description": self.__doc__,
@@ -30,11 +31,23 @@ def tool_spec(self):
     def name(self):
         return self.__class__.__name__
 
-    def bind(self, args):
-        print(f'binding tool to: {args}')
-        if isinstance(args, str):
-            args = json.loads(args)
-        return functools.partial(self, **args)
+    def bind(self, args: str) -> callable:
+        """Returns a callable that corresponds to a tool function called with
+        a string of arguments, which should be in JSON format. The callable
+        itself returns a (message: str, result: dict, needs_feedback: bool)
+        tuple, where message is an informative text message as generated by
+        the tool, result is a dict with name, args, and content keys that
+        correspond to the name and arguments of the function and the result of
+        the tool call. needs_feedback indicates whether the model should be
+        called again to provide feedback based on the tool result.
+        """
+        def bound_tool_function():
+            message, result, needs_feedback = self(**json.loads(args))
+            result = {'name': self.name,
+                      'args': args,
+                      'content': result}
+            return message, result, needs_feedback
+        return bound_tool_function
 
     def __call__(self) -> Tuple[str, Optional[str], bool]:
         """Should be implemented in a tool with additional arguments that
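
To illustrate the bind() contract described in the new docstring, here is a minimal hypothetical tool (not part of this commit; assumes a heymans instance):

    class search_documentation(BaseTool):
        """Searches the documentation for a query."""

        def __call__(self, query):
            # Returns (message, raw result content, needs_feedback)
            return f'Searching for {query}', f'results for {query}', True

    tool = search_documentation(heymans)
    message, result, needs_feedback = tool.bind('{"query": "staircase"}')()
    # result == {'name': 'search_documentation',
    #            'args': '{"query": "staircase"}',
    #            'content': 'results for staircase'}
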