Skip to content

Commit

Permalink
Add support for Cohere models
Browse files Browse the repository at this point in the history
Added support for Cohere models by introducing the `models/cohere.py`
file. Currently, the support is limited to non-streaming, but streaming
support is planned soon.
  • Loading branch information
rafidka committed Dec 25, 2024
1 parent b9ec73f commit dfb7326
Show file tree
Hide file tree
Showing 6 changed files with 461 additions and 3 deletions.
17 changes: 17 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,19 @@
'claude-3-5-haiku-latest',
'claude-3-5-sonnet-latest',
'claude-3-opus-latest',
'cohere:c4ai-aya-expanse-32b',
'cohere:c4ai-aya-expanse-8b',
'cohere:command',
'cohere:command-light',
'cohere:command-light-nightly',
'cohere:command-nightly',
'cohere:command-r',
'cohere:command-r-03-2024',
'cohere:command-r-08-2024',
'cohere:command-r-plus',
'cohere:command-r-plus-04-2024',
'cohere:command-r-plus-08-2024',
'cohere:command-r7b-12-2024',
'test',
]
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
Expand Down Expand Up @@ -270,6 +283,10 @@ def infer_model(model: Model | KnownModelName) -> Model:
from .test import TestModel

return TestModel()
elif model.startswith('cohere:'):
from .cohere import CohereModel

return CohereModel(model[7:])
elif model.startswith('openai:'):
from .openai import OpenAIModel

Expand Down
283 changes: 283 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
from __future__ import annotations as _annotations

from collections.abc import Iterable
from dataclasses import dataclass, field
from itertools import chain
from typing import Literal, Union

from typing_extensions import assert_never

from .. import result
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
ModelMessage,
ModelRequest,
ModelResponse,
ModelResponsePart,
RetryPromptPart,
SystemPromptPart,
TextPart,
ToolCallPart,
ToolReturnPart,
UserPromptPart,
)
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import (
AgentModel,
Model,
check_allow_model_requests,
)

try:
from cohere import (
AssistantChatMessageV2,
AsyncClientV2,
ChatMessageV2,
ChatResponse,
SystemChatMessageV2,
ToolCallV2,
ToolCallV2Function,
ToolChatMessageV2,
ToolV2,
ToolV2Function,
UserChatMessageV2,
)
from cohere.v2.client import OMIT
except ImportError as _import_error:
raise ImportError(
'Please install `cohere` to use the Cohere model, '
"you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
) from _import_error

# NOTE: this is deliberately a plain assignment rather than the `type X = ...` statement.
# The `type` statement only parses on Python 3.12+, so using it would make this module fail
# to import with a `SyntaxError` on older interpreters.
CohereModelName = Union[
    str,
    Literal[
        'c4ai-aya-expanse-32b',
        'c4ai-aya-expanse-8b',
        'command',
        'command-light',
        'command-light-nightly',
        'command-nightly',
        'command-r',
        'command-r-03-2024',
        'command-r-08-2024',
        'command-r-plus',
        'command-r-plus-04-2024',
        'command-r-plus-08-2024',
        'command-r7b-12-2024',
    ],
]
"""Name of a Cohere model.

Besides the known model names listed above, any other `str` is accepted, so models that are
not listed here can still be used.
"""


@dataclass(init=False)
class CohereModel(Model):
    """Model implementation backed by the Cohere API.

    Requests go through the [Cohere Python client](
    https://github.com/cohere-ai/cohere-python).

    Only `__init__` is specific to this class; every other method is either private or
    implements the base class interface.
    """

    model_name: CohereModelName
    client: AsyncClientV2 = field(repr=False)

    def __init__(
        self,
        model_name: CohereModelName,
        *,
        api_key: str | None = None,
        cohere_client: AsyncClientV2 | None = None,
    ):
        """Initialize a Cohere model.

        Args:
            model_name: The name of the Cohere model to use. List of model names
                available [here](https://docs.cohere.com/docs/models#command).
            api_key: The API key to use for authentication, if not provided, the
                `COHERE_API_KEY` environment variable will be used if available.
            cohere_client: An existing Cohere async client to use. If provided,
                `api_key` must be `None`.
        """
        self.model_name: CohereModelName = model_name
        if cohere_client is None:
            # No client supplied: build one from the (possibly absent) API key.
            self.client = AsyncClientV2(api_key=api_key)  # type: ignore
        else:
            # A ready-made client already carries its own credentials.
            assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
            self.client = cohere_client

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        """Build the `CohereAgentModel` used to run requests for an agent.

        `check_allow_model_requests()` is called first; the function tools followed by
        the result tools are then converted to Cohere tool definitions.
        """
        check_allow_model_requests()
        tool_definitions = [*function_tools, *(result_tools or ())]
        cohere_tools = [self._map_tool_definition(definition) for definition in tool_definitions]
        return CohereAgentModel(
            client=self.client,
            model_name=self.model_name,
            allow_text_result=allow_text_result,
            tools=cohere_tools,
        )

    def name(self) -> str:
        """Return the model name prefixed with `cohere:`."""
        return f'cohere:{self.model_name}'

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> ToolV2:
        """Convert a `ToolDefinition` into a Cohere `ToolV2` of type `function`."""
        function = ToolV2Function(
            name=f.name,
            description=f.description,
            parameters=f.parameters_json_schema,
        )
        return ToolV2(type='function', function=function)


@dataclass
class CohereAgentModel(AgentModel):
    """`AgentModel` implementation that sends requests through a Cohere client."""

    client: AsyncClientV2
    model_name: CohereModelName
    allow_text_result: bool
    tools: list[ToolV2]

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, result.Usage]:
        """Send `messages` to the model and return the mapped response together with its usage."""
        chat_response = await self._chat(messages, model_settings)
        model_response = self._process_response(chat_response)
        return model_response, _map_usage(chat_response)

    async def _chat(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
    ) -> ChatResponse:
        """Map `messages` to Cohere chat messages and call the client's `chat` method.

        `tools` and the `max_tokens`, `temperature` and `top_p` settings are forwarded when
        present; `OMIT` is passed in their place otherwise.
        """
        cohere_messages: list[ChatMessageV2] = []
        for message in messages:
            # One pydantic-ai message may expand to several Cohere messages.
            cohere_messages.extend(self._map_message(message))
        settings = model_settings or {}
        return await self.client.chat(
            model=self.model_name,
            messages=cohere_messages,
            tools=self.tools or OMIT,
            max_tokens=settings.get('max_tokens', OMIT),
            temperature=settings.get('temperature', OMIT),
            p=settings.get('top_p', OMIT),
        )

    @staticmethod
    def _process_response(response: ChatResponse) -> ModelResponse:
        """Turn a non-streamed chat response into a `ModelResponse`.

        Only the text of the first content item is used. Tool calls without a function,
        a function name or arguments are skipped.
        """
        parts: list[ModelResponsePart] = []
        content = response.message.content
        if content:
            first_item = content[0]
            parts.append(TextPart(first_item.text))
        for call in response.message.tool_calls or []:
            function = call.function
            if not (function and function.name and function.arguments):
                continue
            tool_call_part = ToolCallPart.from_raw_args(
                tool_name=function.name,
                args=function.arguments,
                tool_call_id=call.id,
            )
            parts.append(tool_call_part)
        return ModelResponse(parts=parts)

    @classmethod
    def _map_message(cls, message: ModelMessage) -> Iterable[ChatMessageV2]:
        """Yield the Cohere chat message(s) corresponding to one `ModelMessage`."""
        if isinstance(message, ModelResponse):
            text_chunks: list[str] = []
            mapped_calls: list[ToolCallV2] = []
            for part in message.parts:
                if isinstance(part, ToolCallPart):
                    mapped_calls.append(_map_tool_call(part))
                elif isinstance(part, TextPart):
                    text_chunks.append(part.content)
                else:
                    assert_never(part)
            assistant_message = AssistantChatMessageV2(role='assistant')
            if text_chunks:
                # Responses produced by this model carry a single text item, so joining only
                # merges several texts when the history comes from a different model.
                assistant_message.content = '\n\n'.join(text_chunks)
            if mapped_calls:
                assistant_message.tool_calls = mapped_calls
            yield assistant_message
        elif isinstance(message, ModelRequest):
            yield from cls._map_user_message(message)
        else:
            assert_never(message)

    @classmethod
    def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
        """Yield one Cohere chat message per part of a `ModelRequest`."""
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield SystemChatMessageV2(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                yield UserChatMessageV2(role='user', content=part.content)
            elif isinstance(part, ToolReturnPart):
                call_id = _guard_tool_call_id(t=part, model_source='Cohere')
                yield ToolChatMessageV2(
                    role='tool',
                    tool_call_id=call_id,
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):
                if part.tool_name is not None:
                    # The retry refers to a tool call, so it is sent back as a tool message.
                    call_id = _guard_tool_call_id(t=part, model_source='Cohere')
                    yield ToolChatMessageV2(
                        role='tool',
                        tool_call_id=call_id,
                        content=part.model_response(),
                    )
                else:
                    yield UserChatMessageV2(role='user', content=part.model_response())
            else:
                assert_never(part)


def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
    """Convert a `ToolCallPart` into a Cohere `ToolCallV2` of type `function`.

    The tool call id is obtained through `_guard_tool_call_id`, and the arguments are
    passed as the JSON string returned by `args_as_json_str()`.
    """
    call_id = _guard_tool_call_id(t=t, model_source='Cohere')
    function = ToolCallV2Function(name=t.tool_name, arguments=t.args_as_json_str())
    return ToolCallV2(id=call_id, type='function', function=function)


def _map_usage(response: ChatResponse) -> result.Usage:
    """Build a `result.Usage` from the usage information of a chat response.

    Returns an empty `result.Usage()` when the response carries no usage data. Otherwise
    the non-zero billed units are copied into `details`, and the request/response token
    counts are taken from `usage.tokens` (`None` when missing or zero). `total_tokens` is
    the sum of both counts, with a missing count treated as 0.
    """
    usage = response.usage
    if usage is None:
        return result.Usage()

    details: dict[str, int] = {}
    billed = usage.billed_units
    if billed is not None:
        for unit in ('input_tokens', 'output_tokens', 'search_units', 'classifications'):
            amount = getattr(billed, unit)
            if amount:
                details[unit] = int(amount)

    tokens = usage.tokens
    request_tokens = int(tokens.input_tokens) if tokens and tokens.input_tokens else None
    response_tokens = int(tokens.output_tokens) if tokens and tokens.output_tokens else None
    total_tokens = (request_tokens or 0) + (response_tokens or 0)
    return result.Usage(
        request_tokens=request_tokens,
        response_tokens=response_tokens,
        total_tokens=total_tokens,
        details=details,
    )
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class ModelSettings(TypedDict, total=False):
* Anthropic
* OpenAI
* Groq
* Cohere
"""

temperature: float
Expand All @@ -43,6 +44,7 @@ class ModelSettings(TypedDict, total=False):
* Anthropic
* OpenAI
* Groq
* Cohere
"""

top_p: float
Expand All @@ -58,6 +60,7 @@ class ModelSettings(TypedDict, total=False):
* Anthropic
* OpenAI
* Groq
* Cohere
"""

timeout: float | Timeout
Expand Down
1 change: 1 addition & 0 deletions pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ dependencies = [

[project.optional-dependencies]
openai = ["openai>=1.54.3"]
cohere = ["cohere>=5.13.4"]
vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"]
anthropic = ["anthropic>=0.40.0"]
groq = ["groq>=0.12.0"]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ classifiers = [
]
requires-python = ">=3.9"

dependencies = ["pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral]==0.0.15"]
dependencies = ["pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral,cohere]==0.0.15"]

[project.urls]
Homepage = "https://ai.pydantic.dev"
Expand Down
Loading

0 comments on commit dfb7326

Please sign in to comment.