diff --git a/heymans/config.py b/heymans/config.py index 6b95165..762d700 100644 --- a/heymans/config.py +++ b/heymans/config.py @@ -119,8 +119,8 @@ 'answer_model': 'gpt-4' }, 'anthropic': { - 'search_model': 'mistral-medium', - 'condense_model': 'mistral-medium', + 'search_model': 'claude-3-sonnet', + 'condense_model': 'claude-3-sonnet', 'answer_model': 'claude-3-opus' }, 'mistral': { @@ -134,6 +134,8 @@ 'answer_model': 'dummy' } } +# Model-specific arguments +anthropic_max_tokens = 1024 # TOOLS # diff --git a/heymans/model/_anthropic_model.py b/heymans/model/_anthropic_model.py index e8f9cd8..c6b048b 100644 --- a/heymans/model/_anthropic_model.py +++ b/heymans/model/_anthropic_model.py @@ -1,28 +1,116 @@ from . import BaseModel -from .. import config +from .. import config, utils +import logging +logger = logging.getLogger('heymans') class AnthropicModel(BaseModel): - max_retry = 3 - - def __init__(self, heymans, model): - from langchain_anthropic import ChatAnthropic - super().__init__(heymans) - self._model = ChatAnthropic( - model=model, anthropic_api_key=config.anthropic_api_key) + def __init__(self, heymans, model, **kwargs): + from anthropic import Anthropic, AsyncAnthropic + super().__init__(heymans, **kwargs) + self._model = model + self._tool_use_id = 0 + self._client = Anthropic(api_key=config.anthropic_api_key) + self._async_client = AsyncAnthropic(api_key=config.anthropic_api_key) def predict(self, messages): - if isinstance(messages, list): - messages = utils.prepare_messages(messages, allow_ai_first=False, - allow_ai_last=False, - merge_consecutive=True) - # Claude seems to crash occasionally, in which case a retry will do the - # trick - for i in range(self.max_retry): - try: - return super().predict(messages) - except Exception as e: - logger.warning(f'error in prediction (retrying): {e}') + if isinstance(messages, str): + return super().predict([self.convert_message(messages)]) + messages = utils.prepare_messages(messages, allow_ai_first=False, + allow_ai_last=False, + merge_consecutive=True) + messages = [self.convert_message(message) for message in messages] + # The Anthropic messages API doesn't accept tool results in a separate + # message. Instead, tool results are included as a special content + # block in a user message. Since two subsequent user messages aren't + # allowed, we need to convert a tool message to a user message and if + # necessary merge it with the next user message. + while True: + logger.info('entering message postprocessing loop') + for i, message in enumerate(messages): + if message['role'] == 'tool': + logger.info('converting tool message to user message') + message['role'] = 'user' + message['content'] = [{ + 'type': 'tool_result', + 'tool_use_id': str(self._tool_use_id), + 'content': [{ + 'type': 'text', + 'text': message['content'] + }] + }] + if i > 0: + # The previous message needs to have a tool-use block + prev_message = messages[i - 1] + prev_message['content'] = [ + {'type': 'text', + 'text': prev_message['content']}, + {'type': 'tool_use', + 'id': str(self._tool_use_id), + 'input': {'args': 'input args'}, + 'name': 'tool_function' + } + ] + self._tool_use_id += 1 + if len(messages) > i + 1: + logger.info('merging tool and user message') + next_message = messages[i + 1] + if next_message['role'] == 'user': + message['content'].append([{ + "type": "text", + "text": next_message['content'] + }]) + break + else: + break + logger.info('dropping duplicate user message') + messages.remove(next_message) return super().predict(messages) + def get_response(self, response): + text = [] + for block in response.content: + if block.type == 'tool_use': + for tool in self._tools: + if tool.name == block.name: + return tool.bind(block.input) + return self.invalid_tool + if block.type == 'text': + text.append(block.text) + return '\n'.join(text) + + + def _tool_args(self): + if not self._tools: + return {} + alternative_format_tools = [] + for tool in self.tools(): + if tool['type'] == 'function': + function = tool['function'] + alt_tool = { + "name": function['name'], + "description": function['description'], + "input_schema": function['parameters'] + } + alternative_format_tools.append(alt_tool) + return {'tools': alternative_format_tools} + + def _anthropic_invoke(self, fnc, messages): + kwargs = self._tool_args() + # If the first message is the system prompt, we need to separate this + # from the user and assistant messages, because the Anthropic messages + # API takes this as a separate keyword argument + if messages[0]['role'] == 'system': + kwargs['system'] = messages[0]['content'] + messages = messages[1:] + return fnc(model=self._model, max_tokens=config.anthropic_max_tokens, + messages=messages, **kwargs) + + def invoke(self, messages): + return self._anthropic_invoke( + self._client.beta.tools.messages.create, messages) + + def async_invoke(self, messages): + return self._anthropic_invoke( + self._async_client.beta.tools.messages.create, messages) diff --git a/heymans/model/_base_model.py b/heymans/model/_base_model.py index f7f875e..6458e9a 100644 --- a/heymans/model/_base_model.py +++ b/heymans/model/_base_model.py @@ -1,6 +1,8 @@ import logging import asyncio import time +from langchain.schema import SystemMessage, AIMessage, HumanMessage, \ + FunctionMessage logger = logging.getLogger('heymans') @@ -17,7 +19,7 @@ def __init__(self, heymans, tools=None, tool_choice='auto'): self.prompt_tokens_consumed = 0 self.completion_tokens_consumed = 0 - def invalid_tool(self): + def invalid_tool(self) -> str: return 'Invalid tool' def get_response(self, response) -> [str, callable]: @@ -28,16 +30,31 @@ def tools(self): for t in self._tools if t.tool_spec] def invoke(self, messages): - return self._model.invoke(messages) + raise NotImplementedError() def async_invoke(self, messages): - return self._model.ainvoke(messages) + raise NotImplementedError() - def messages_length(self, messages): + def messages_length(self, messages) -> int: if isinstance(messages, str): - return len(messages) + return lebase_format_toolsn(messages) return sum([len(m.content if hasattr(m, 'content') else m['content']) for m in messages]) + + def convert_message(self, message): + if isinstance(message, str): + return dict(role='user', content=message) + if isinstance(message, SystemMessage): + role = 'system' + elif isinstance(message, AIMessage): + role = 'assistant' + elif isinstance(message, HumanMessage): + role = 'user' + elif isinstance(message, FunctionMessage): + role = 'tool' + else: + raise ValueError(f'Unknown message type: {message}') + return dict(role=role, content=message.content) def predict(self, messages, track_tokens=True): t0 = time.time() @@ -63,6 +80,7 @@ def predict_multiple(self, prompts): """Predicts multiple simple (non-message history) prompts using asyncio if possible. """ + prompts = [[self.convert_message(prompt)] for prompt in prompts] try: loop = asyncio.get_event_loop() if not loop.is_running(): diff --git a/heymans/model/_openai_model.py b/heymans/model/_openai_model.py index 3717615..9180caf 100644 --- a/heymans/model/_openai_model.py +++ b/heymans/model/_openai_model.py @@ -1,7 +1,5 @@ from .. import config from . import BaseModel -from langchain.schema import SystemMessage, AIMessage, HumanMessage, \ - FunctionMessage class OpenAIModel(BaseModel): @@ -18,22 +16,6 @@ def __init__(self, heymans, model, **kwargs): self._client = Client(api_key=config.openai_api_key) self._async_client = AsyncClient(api_key=config.openai_api_key) - def convert_message(self, message): - # OpenAI expects messages as dict objects - if isinstance(message, str): - return dict(role='user', content=message) - if isinstance(message, SystemMessage): - role = 'system' - elif isinstance(message, AIMessage): - role = 'assistant' - elif isinstance(message, HumanMessage): - role = 'user' - elif isinstance(message, FunctionMessage): - role = 'tool' - else: - raise ValueError(f'Unknown message type: {message}') - return dict(role=role, content=message.content) - def predict(self, messages): # Strings need to be converted a list of length one with a single # message dict @@ -61,10 +43,6 @@ def predict(self, messages): message['tool_call_id'] = tool_call_id return super().predict(messages) - def predict_multiple(self, prompts): - prompts = [[self.convert_message(prompt)] for prompt in prompts] - return super().predict_multiple(prompts) - def get_response(self, response): tool_calls = response.choices[0].message.tool_calls if tool_calls: diff --git a/heymans/tools/_base_tool.py b/heymans/tools/_base_tool.py index f795d44..c3b9680 100644 --- a/heymans/tools/_base_tool.py +++ b/heymans/tools/_base_tool.py @@ -31,6 +31,7 @@ def name(self): return self.__class__.__name__ def bind(self, args): + print(f'binding tool to: {args}') if isinstance(args, str): args = json.loads(args) return functools.partial(self, **args)