Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Cohere models #203

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,19 @@
'claude-3-5-haiku-latest',
'claude-3-5-sonnet-latest',
'claude-3-opus-latest',
'cohere:c4ai-aya-expanse-32b',
'cohere:c4ai-aya-expanse-8b',
'cohere:command',
'cohere:command-light',
'cohere:command-light-nightly',
'cohere:command-nightly',
'cohere:command-r',
'cohere:command-r-03-2024',
'cohere:command-r-08-2024',
'cohere:command-r-plus',
'cohere:command-r-plus-04-2024',
'cohere:command-r-plus-08-2024',
'cohere:command-r7b-12-2024',
'test',
]
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
Expand Down Expand Up @@ -270,6 +283,10 @@ def infer_model(model: Model | KnownModelName) -> Model:
from .test import TestModel

return TestModel()
elif model.startswith('cohere:'):
from .cohere import CohereModel

return CohereModel(model[7:])
elif model.startswith('openai:'):
from .openai import OpenAIModel

Expand Down
283 changes: 283 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
from __future__ import annotations as _annotations

from collections.abc import Iterable
from dataclasses import dataclass, field
from itertools import chain
from typing import Literal, Union

from typing_extensions import assert_never

from .. import result
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
ModelMessage,
ModelRequest,
ModelResponse,
ModelResponsePart,
RetryPromptPart,
SystemPromptPart,
TextPart,
ToolCallPart,
ToolReturnPart,
UserPromptPart,
)
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import (
AgentModel,
Model,
check_allow_model_requests,
)

try:
from cohere import (
AssistantChatMessageV2,
AsyncClientV2,
ChatMessageV2,
ChatResponse,
SystemChatMessageV2,
ToolCallV2,
ToolCallV2Function,
ToolChatMessageV2,
ToolV2,
ToolV2Function,
UserChatMessageV2,
)
from cohere.v2.client import OMIT
except ImportError as _import_error:
raise ImportError(
'Please install `cohere` to use the Cohere model, '
"you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
) from _import_error

# NOTE: deliberately a plain assignment rather than the PEP 695 `type` statement,
# which is Python 3.12+ only — this package supports Python 3.9+.
CohereModelName = Union[
    str,
    Literal[
        'c4ai-aya-expanse-32b',
        'c4ai-aya-expanse-8b',
        'command',
        'command-light',
        'command-light-nightly',
        'command-nightly',
        'command-r',
        'command-r-03-2024',
        'command-r-08-2024',
        'command-r-plus',
        'command-r-plus-04-2024',
        'command-r-plus-08-2024',
        'command-r7b-12-2024',
    ],
]
"""Possible Cohere model names.

Using this broader type (a union with `str`) instead of only the literal model
names allows this model class to be used more easily with other model
providers (e.g. Ollama).
"""


@dataclass(init=False)
class CohereModel(Model):
    """A model that talks to the Cohere API.

    Requests are made through the [Cohere Python client](
    https://github.com/cohere-ai/cohere-python).

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    model_name: CohereModelName
    client: AsyncClientV2 = field(repr=False)

    def __init__(
        self,
        model_name: CohereModelName,
        *,
        api_key: str | None = None,
        cohere_client: AsyncClientV2 | None = None,
    ):
        """Initialize a Cohere model.

        Args:
            model_name: The name of the Cohere model to use. List of model names
                available [here](https://docs.cohere.com/docs/models#command).
            api_key: The API key to use for authentication, if not provided, the
                `COHERE_API_KEY` environment variable will be used if available.
            cohere_client: An existing Cohere async client to use. If provided,
                `api_key` must be `None`.
        """
        self.model_name: CohereModelName = model_name
        if cohere_client is None:
            # The client falls back to the COHERE_API_KEY env var when api_key is None.
            self.client = AsyncClientV2(api_key=api_key)  # type: ignore
        else:
            assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
            self.client = cohere_client

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        """Build an `AgentModel` bound to this client, with all tools mapped to Cohere's schema."""
        check_allow_model_requests()
        # Function tools and result tools share one flat tool list in the Cohere API.
        mapped_tools = [self._map_tool_definition(tool) for tool in chain(function_tools, result_tools)]
        return CohereAgentModel(
            self.client,
            self.model_name,
            allow_text_result,
            mapped_tools,
        )

    def name(self) -> str:
        """Return the prefixed model identifier, e.g. `cohere:command-r`."""
        return 'cohere:{}'.format(self.model_name)

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> ToolV2:
        """Translate a pydantic-ai `ToolDefinition` into Cohere's `ToolV2` wire format."""
        function = ToolV2Function(
            name=f.name,
            description=f.description,
            parameters=f.parameters_json_schema,
        )
        return ToolV2(type='function', function=function)


@dataclass
class CohereAgentModel(AgentModel):
    """Implementation of `AgentModel` for Cohere models."""

    # Async Cohere v2 client used for all chat requests.
    client: AsyncClientV2
    # Model identifier without the `cohere:` prefix.
    model_name: CohereModelName
    # Whether a plain-text (non-tool-call) result is acceptable to the agent.
    allow_text_result: bool
    # Tools (function + result tools) already mapped to Cohere's ToolV2 schema.
    tools: list[ToolV2]

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, result.Usage]:
        """Make a non-streaming request and return the parsed response plus token usage."""
        response = await self._chat(messages, model_settings)
        return self._process_response(response), _map_usage(response)

    async def _chat(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
    ) -> ChatResponse:
        """Map the message history to Cohere's format and call the chat endpoint.

        Unset settings are passed as the Cohere client's OMIT sentinel so they are
        excluded from the request rather than sent as `None`.
        """
        # Each ModelMessage may expand to several Cohere messages; flatten them in order.
        cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
        model_settings = model_settings or {}
        return await self.client.chat(
            model=self.model_name,
            messages=cohere_messages,
            tools=self.tools or OMIT,
            max_tokens=model_settings.get('max_tokens', OMIT),
            temperature=model_settings.get('temperature', OMIT),
            p=model_settings.get('top_p', OMIT),
        )

    @staticmethod
    def _process_response(response: ChatResponse) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return.

        Emits at most one `TextPart` (from the first content item — presumably
        responses carry a single text item; TODO confirm against the Cohere API)
        followed by one `ToolCallPart` per well-formed tool call.
        """
        parts: list[ModelResponsePart] = []
        if response.message.content is not None and len(response.message.content) > 0:
            choice = response.message.content[0]
            parts.append(TextPart(choice.text))
        for c in response.message.tool_calls or []:
            # Skip tool calls missing a function, name, or arguments.
            if c.function and c.function.name and c.function.arguments:
                parts.append(
                    ToolCallPart.from_raw_args(
                        tool_name=c.function.name,
                        args=c.function.arguments,
                        tool_call_id=c.id,
                    )
                )
        return ModelResponse(parts=parts)

    @classmethod
    def _map_message(cls, message: ModelMessage) -> Iterable[ChatMessageV2]:
        """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""
        if isinstance(message, ModelRequest):
            yield from cls._map_user_message(message)
        elif isinstance(message, ModelResponse):
            texts: list[str] = []
            tool_calls: list[ToolCallV2] = []
            for item in message.parts:
                if isinstance(item, TextPart):
                    texts.append(item.content)
                elif isinstance(item, ToolCallPart):
                    tool_calls.append(_map_tool_call(item))
                else:
                    assert_never(item)
            message_param = AssistantChatMessageV2(role='assistant')
            if texts:
                # Note: model responses from this model should only have one text item, so the following
                # shouldn't merge multiple texts into one unless you switch models between runs:
                message_param.content = '\n\n'.join(texts)
            if tool_calls:
                message_param.tool_calls = tool_calls
            yield message_param
        else:
            assert_never(message)

    @classmethod
    def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
        """Expand a `ModelRequest` into one Cohere message per request part."""
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield SystemChatMessageV2(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                yield UserChatMessageV2(role='user', content=part.content)
            elif isinstance(part, ToolReturnPart):
                yield ToolChatMessageV2(
                    role='tool',
                    tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):
                # A retry with no tool name came from plain text validation, so it
                # goes back as a user message; otherwise it is a tool result.
                if part.tool_name is None:
                    yield UserChatMessageV2(role='user', content=part.model_response())
                else:
                    yield ToolChatMessageV2(
                        role='tool',
                        tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
                        content=part.model_response(),
                    )
            else:
                assert_never(part)


def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
    """Convert a pydantic-ai `ToolCallPart` into Cohere's `ToolCallV2` shape."""
    function = ToolCallV2Function(
        name=t.tool_name,
        arguments=t.args_as_json_str(),
    )
    call_id = _guard_tool_call_id(t=t, model_source='Cohere')
    return ToolCallV2(id=call_id, type='function', function=function)


def _map_usage(response: ChatResponse) -> result.Usage:
    """Translate Cohere's usage block into a pydantic-ai `result.Usage`.

    Billed-unit counts go into `details`; raw token counts populate the
    request/response/total token fields. Missing usage yields an empty `Usage`.
    """
    usage = response.usage
    if usage is None:
        return result.Usage()

    details: dict[str, int] = {}
    billed = usage.billed_units
    if billed is not None:
        # Only record counts that are present and non-zero.
        for key, value in (
            ('input_tokens', billed.input_tokens),
            ('output_tokens', billed.output_tokens),
            ('search_units', billed.search_units),
            ('classifications', billed.classifications),
        ):
            if value:
                details[key] = int(value)

    tokens = usage.tokens
    request_tokens = int(tokens.input_tokens) if tokens and tokens.input_tokens else None
    response_tokens = int(tokens.output_tokens) if tokens and tokens.output_tokens else None
    return result.Usage(
        request_tokens=request_tokens,
        response_tokens=response_tokens,
        total_tokens=(request_tokens or 0) + (response_tokens or 0),
        details=details,
    )
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class ModelSettings(TypedDict, total=False):
* Anthropic
* OpenAI
* Groq
* Cohere
"""

temperature: float
Expand All @@ -43,6 +44,7 @@ class ModelSettings(TypedDict, total=False):
* Anthropic
* OpenAI
* Groq
* Cohere
"""

top_p: float
Expand All @@ -58,6 +60,7 @@ class ModelSettings(TypedDict, total=False):
* Anthropic
* OpenAI
* Groq
* Cohere
"""

timeout: float | Timeout
Expand Down
1 change: 1 addition & 0 deletions pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ dependencies = [

[project.optional-dependencies]
openai = ["openai>=1.54.3"]
cohere = ["cohere>=5.13.4"]
vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"]
anthropic = ["anthropic>=0.40.0"]
groq = ["groq>=0.12.0"]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ classifiers = [
]
requires-python = ">=3.9"

dependencies = ["pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral]==0.0.15"]
dependencies = ["pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral,cohere]==0.0.15"]

[project.urls]
Homepage = "https://ai.pydantic.dev"
Expand Down
Loading
Loading