Skip to content

Commit

Permalink
Async models
Browse files Browse the repository at this point in the history
* Tip about pytest --record-mode once

Plus mechanism for setting API key during tests with PYTEST_ANTHROPIC_API_KEY

* Async support for Claude models

Closes #25
Refs simonw/llm#507
Refs simonw/llm#613

* Depend on llm>=0.18a0, refs #25
  • Loading branch information
simonw authored Nov 14, 2024
1 parent 40b7dca commit 8dc7785
Show file tree
Hide file tree
Showing 6 changed files with 770 additions and 22 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,15 @@ To run the tests:
```bash
pytest
```

This project uses [pytest-recording](https://github.com/kiwicom/pytest-recording) to record Anthropic API responses for the tests.

If you add a new test that calls the API you can capture the API response like this:
```bash
PYTEST_ANTHROPIC_API_KEY="$(llm keys get claude)" pytest --record-mode once
```
You will need to have stored a valid Anthropic API key using this command first:
```bash
llm keys set claude
# Paste key here
```
87 changes: 69 additions & 18 deletions llm_claude_3.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from anthropic import Anthropic
from anthropic import Anthropic, AsyncAnthropic
import llm
from pydantic import Field, field_validator, model_validator
from typing import Optional, List
Expand All @@ -7,19 +7,42 @@
@llm.hookimpl
def register_models(register):
# https://docs.anthropic.com/claude/docs/models-overview
register(ClaudeMessages("claude-3-opus-20240229"))
register(ClaudeMessages("claude-3-opus-latest"), aliases=("claude-3-opus",))
register(ClaudeMessages("claude-3-sonnet-20240229"), aliases=("claude-3-sonnet",))
register(ClaudeMessages("claude-3-haiku-20240307"), aliases=("claude-3-haiku",))
register(
ClaudeMessages("claude-3-opus-20240229"),
AsyncClaudeMessages("claude-3-opus-20240229"),
),
register(
ClaudeMessages("claude-3-opus-latest"),
AsyncClaudeMessages("claude-3-opus-latest"),
aliases=("claude-3-opus",),
)
register(
ClaudeMessages("claude-3-sonnet-20240229"),
AsyncClaudeMessages("claude-3-sonnet-20240229"),
aliases=("claude-3-sonnet",),
)
register(
ClaudeMessages("claude-3-haiku-20240307"),
AsyncClaudeMessages("claude-3-haiku-20240307"),
aliases=("claude-3-haiku",),
)
# 3.5 models
register(ClaudeMessagesLong("claude-3-5-sonnet-20240620"))
register(ClaudeMessagesLong("claude-3-5-sonnet-20241022", supports_pdf=True)),
register(
ClaudeMessagesLong("claude-3-5-sonnet-20240620"),
AsyncClaudeMessagesLong("claude-3-5-sonnet-20240620"),
)
register(
ClaudeMessagesLong("claude-3-5-sonnet-20241022", supports_pdf=True),
AsyncClaudeMessagesLong("claude-3-5-sonnet-20241022", supports_pdf=True),
)
register(
ClaudeMessagesLong("claude-3-5-sonnet-latest", supports_pdf=True),
AsyncClaudeMessagesLong("claude-3-5-sonnet-latest", supports_pdf=True),
aliases=("claude-3.5-sonnet", "claude-3.5-sonnet-latest"),
)
register(
ClaudeMessagesLong("claude-3-5-haiku-latest", supports_images=False),
AsyncClaudeMessagesLong("claude-3-5-haiku-latest", supports_images=False),
aliases=("claude-3.5-haiku",),
)

Expand Down Expand Up @@ -86,7 +109,13 @@ def validate_temperature_top_p(self):
return self


class ClaudeMessages(llm.Model):
long_field = Field(
description="The maximum number of tokens to generate before stopping",
default=4_096 * 2,
)


class _Shared:
needs_key = "claude"
key_env_var = "ANTHROPIC_API_KEY"
can_stream = True
Expand Down Expand Up @@ -178,9 +207,7 @@ def build_messages(self, prompt, conversation) -> List[dict]:
messages.append({"role": "user", "content": prompt.prompt})
return messages

def execute(self, prompt, stream, response, conversation):
client = Anthropic(api_key=self.get_key())

def build_kwargs(self, prompt, conversation):
kwargs = {
"model": self.claude_model_id,
"messages": self.build_messages(prompt, conversation),
Expand All @@ -202,7 +229,17 @@ def execute(self, prompt, stream, response, conversation):

if self.extra_headers:
kwargs["extra_headers"] = self.extra_headers
return kwargs

def __str__(self):
return "Anthropic Messages: {}".format(self.model_id)


class ClaudeMessages(_Shared, llm.Model):

def execute(self, prompt, stream, response, conversation):
client = Anthropic(api_key=self.get_key())
kwargs = self.build_kwargs(prompt, conversation)
if stream:
with client.messages.stream(**kwargs) as stream:
for text in stream.text_stream:
Expand All @@ -214,13 +251,27 @@ def execute(self, prompt, stream, response, conversation):
yield completion.content[0].text
response.response_json = completion.model_dump()

def __str__(self):
return "Anthropic Messages: {}".format(self.model_id)


class ClaudeMessagesLong(ClaudeMessages):
class Options(ClaudeOptions):
max_tokens: Optional[int] = Field(
description="The maximum number of tokens to generate before stopping",
default=4_096 * 2,
)
max_tokens: Optional[int] = long_field


class AsyncClaudeMessages(_Shared, llm.AsyncModel):
async def execute(self, prompt, stream, response, conversation):
client = AsyncAnthropic(api_key=self.get_key())
kwargs = self.build_kwargs(prompt, conversation)
if stream:
async with client.messages.stream(**kwargs) as stream_obj:
async for text in stream_obj.text_stream:
yield text
response.response_json = (await stream_obj.get_final_message()).model_dump()
else:
completion = await client.messages.create(**kwargs)
yield completion.content[0].text
response.response_json = completion.model_dump()


class AsyncClaudeMessagesLong(AsyncClaudeMessages):
class Options(ClaudeOptions):
max_tokens: Optional[int] = long_field
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ classifiers = [
"License :: OSI Approved :: Apache Software License"
]
dependencies = [
"llm>=0.17",
"llm>=0.18a0",
"anthropic>=0.39.0",
]

Expand All @@ -23,4 +23,4 @@ CI = "https://github.com/simonw/llm-claude-3/actions"
claude_3 = "llm_claude_3"

[project.optional-dependencies]
test = ["pytest", "pytest-recording"]
test = ["pytest", "pytest-recording", "pytest-asyncio"]
Loading

0 comments on commit 8dc7785

Please sign in to comment.