Skip to content

Commit

Permalink
Allow model-specific generation keywords
Browse files Browse the repository at this point in the history
  • Loading branch information
smathot committed Apr 17, 2024
1 parent 0729100 commit ef3e5e4
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 15 deletions.
9 changes: 7 additions & 2 deletions heymans/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,13 @@
'answer_model': 'dummy'
}
}
# Model-specific arguments
anthropic_max_tokens = 1024
# Model-specific keyword arguments that are passed to the message generation
# functions
anthropic_kwargs = {
'max_tokens': 1024
}
openai_kwargs = {}
mistral_kwargs = {}

# TOOLS
#
Expand Down
4 changes: 2 additions & 2 deletions heymans/model/_anthropic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,14 @@ def _tool_args(self):

def _anthropic_invoke(self, fnc, messages):
kwargs = self._tool_args()
kwargs.update(config.anthropic_kwargs)
# If the first message is the system prompt, we need to separate this
# from the user and assistant messages, because the Anthropic messages
# API takes this as a separate keyword argument
if messages[0]['role'] == 'system':
kwargs['system'] = messages[0]['content']
messages = messages[1:]
return fnc(model=self._model, max_tokens=config.anthropic_max_tokens,
messages=messages, **kwargs)
return fnc(model=self._model, messages=messages, **kwargs)

def invoke(self, messages):
return self._anthropic_invoke(
Expand Down
15 changes: 8 additions & 7 deletions heymans/model/_mistral_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,19 @@ def predict(self, messages):
'content': 'Tool was executed.'})
return BaseModel.predict(self, messages)

def _mistral_tool_args(self, messages):
def _mistral_invoke(self, fnc, messages):
# Mistral tends to get stuck in a loop where the same tool is called
# over and over again. To fix this, we temporarily disallow tools when
# the last message was a tool.
if messages[-1]['role'] == 'tool':
return {}
return self._tool_args()
kwargs = {}
else:
kwargs = self._tool_args()
kwargs.update(config.mistral_kwargs)
return fnc(model=self._model, messages=messages, **kwargs)

def invoke(self, messages):
return self._client.chat(model=self._model, messages=messages,
**self._mistral_tool_args(messages))
return self._mistral_invoke(self._client.chat, messages)

def async_invoke(self, messages):
return self._async_client.chat(model=self._model, messages=messages,
**self._mistral_tool_args(messages))
return self._mistral_invoke(self._async_client.chat, messages)
13 changes: 9 additions & 4 deletions heymans/model/_openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,15 @@ def _tool_args(self):
return {}
return {'tools': self.tools(), 'tool_choice': self._tool_choice}

def _openai_invoke(self, fnc, messages):
kwargs = self._tool_args()
kwargs.update(config.openai_kwargs)
return fnc(model=self._model, messages=messages, **kwargs)

def invoke(self, messages):
return self._client.chat.completions.create(
model=self._model, messages=messages, **self._tool_args())
return self._openai_invoke(
self._client.chat.completions.create, messages=messages)

def async_invoke(self, messages):
return self._async_client.chat.completions.create(
model=self._model, messages=messages, **self._tool_args())
return self._openai_invoke(
self._async_client.chat.completions.create, messages=messages)

0 comments on commit ef3e5e4

Please sign in to comment.