Log input tokens, output tokens and token details (#642)
* Store input_tokens, output_tokens, token_details on Response, closes #610
* llm prompt -u/--usage option
* llm logs -u/--usage option
* Docs on tracking token usage in plugins
* OpenAI default plugin logs usage
simonw authored Nov 20, 2024
1 parent 4a059d7 commit cfb10f4
Showing 14 changed files with 224 additions and 13 deletions.
2 changes: 2 additions & 0 deletions docs/help.md
@@ -122,6 +122,7 @@ Options:
--key TEXT API key to use
--save TEXT Save prompt with this template name
--async Run prompt asynchronously
-u, --usage Show token usage
--help Show this message and exit.
```

@@ -292,6 +293,7 @@ Options:
-m, --model TEXT Filter by model or model alias
-q, --query TEXT Search for logs matching this string
-t, --truncate Truncate long strings in output
-u, --usage Include token usage
-r, --response Just output the last response
-c, --current Show logs from the current conversation
--cid, --conversation TEXT Show logs for this conversation ID
5 changes: 4 additions & 1 deletion docs/logging.md
@@ -159,7 +159,10 @@ CREATE TABLE [responses] (
[response_json] TEXT,
[conversation_id] TEXT REFERENCES [conversations]([id]),
[duration_ms] INTEGER,
[datetime_utc] TEXT
[datetime_utc] TEXT,
[input_tokens] INTEGER,
[output_tokens] INTEGER,
[token_details] TEXT
);
CREATE VIRTUAL TABLE [responses_fts] USING FTS5 (
[prompt],
16 changes: 16 additions & 0 deletions docs/plugins/advanced-model-plugins.md
@@ -167,3 +167,19 @@ for prev_response in conversation.responses:
The `response.text_or_raise()` method used there will return the text from the response or raise a `ValueError` exception if the response is an `AsyncResponse` instance that has not yet been fully resolved.

This is a slightly weird hack to work around the common need to share logic for building up the `messages` list across both sync and async models.

(advanced-model-plugins-usage)=

## Tracking token usage

Models that charge by the token should track the number of tokens used by each prompt. The ``response.set_usage()`` method can be used to record the number of tokens used by a response - these will then be made available through the Python API and logged to the SQLite database for command-line users.

`response` here is the response object that is passed to `.execute()` as an argument.

Call ``response.set_usage()`` at the end of your `.execute()` method. It accepts keyword arguments `input=`, `output=` and `details=` - all three are optional. `input` and `output` should be integers, and `details` should be a dictionary that provides additional information beyond the input and output token counts.

This example logs 15 input tokens, 340 output tokens and notes that 37 tokens were cached:

```python
response.set_usage(input=15, output=340, details={"cached": 37})
```
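
As an illustration (not part of this commit), here is a minimal sketch of a plugin `.execute()` method that reports usage; `my_client` and its `usage` fields are hypothetical placeholders for whatever API the plugin calls:

```python
import llm


class MyModel(llm.Model):
    model_id = "my-model"

    def execute(self, prompt, stream, response, conversation=None):
        # Hypothetical API call - substitute your provider's client
        api_response = my_client.complete(prompt.prompt)
        yield api_response.text
        # Record usage at the end of .execute()
        response.set_usage(
            input=api_response.usage["input_tokens"],
            output=api_response.usage["output_tokens"],
        )
```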
31 changes: 28 additions & 3 deletions llm/cli.py
@@ -33,7 +33,7 @@

from .migrations import migrate
from .plugins import pm, load_plugins
from .utils import mimetype_from_path, mimetype_from_string
from .utils import mimetype_from_path, mimetype_from_string, token_usage_string
import base64
import httpx
import pathlib
@@ -203,6 +203,7 @@ def cli():
@click.option("--key", help="API key to use")
@click.option("--save", help="Save prompt with this template name")
@click.option("async_", "--async", is_flag=True, help="Run prompt asynchronously")
@click.option("-u", "--usage", is_flag=True, help="Show token usage")
def prompt(
prompt,
system,
@@ -220,6 +221,7 @@ def prompt(
key,
save,
async_,
usage,
):
"""
Execute a prompt
@@ -426,14 +428,24 @@ async def inner():
except Exception as ex:
raise click.ClickException(str(ex))

if isinstance(response, AsyncResponse):
response = asyncio.run(response.to_sync_response())

if usage:
# Show token usage to stderr in yellow
click.echo(
click.style(
"Token usage: {}".format(response.token_usage()), fg="yellow", bold=True
),
err=True,
)

# Log to the database
if (logs_on() or log) and not no_log:
log_path = logs_db_path()
(log_path.parent).mkdir(parents=True, exist_ok=True)
db = sqlite_utils.Database(log_path)
migrate(db)
if isinstance(response, AsyncResponse):
response = asyncio.run(response.to_sync_response())
response.log_to_db(db)


@@ -754,6 +766,9 @@ def logs_turn_off():
responses.conversation_id,
responses.duration_ms,
responses.datetime_utc,
responses.input_tokens,
responses.output_tokens,
responses.token_details,
conversations.name as conversation_name,
conversations.model as conversation_model"""

@@ -809,6 +824,7 @@ def logs_turn_off():
@click.option("-m", "--model", help="Filter by model or model alias")
@click.option("-q", "--query", help="Search for logs matching this string")
@click.option("-t", "--truncate", is_flag=True, help="Truncate long strings in output")
@click.option("-u", "--usage", is_flag=True, help="Include token usage")
@click.option("-r", "--response", is_flag=True, help="Just output the last response")
@click.option(
"current_conversation",
@@ -836,6 +852,7 @@ def logs_list(
model,
query,
truncate,
usage,
response,
current_conversation,
conversation_id,
@@ -998,6 +1015,14 @@ def logs_list(
)

click.echo("\n## Response:\n\n{}\n".format(row["response"]))
if usage:
token_usage = token_usage_string(
row["input_tokens"],
row["output_tokens"],
json.loads(row["token_details"]) if row["token_details"] else None,
)
if token_usage:
click.echo("## Token usage:\n\n{}\n".format(token_usage))


@cli.group(
27 changes: 26 additions & 1 deletion llm/default_plugins/openai_models.py
@@ -1,6 +1,11 @@
from llm import AsyncModel, EmbeddingModel, Model, hookimpl
import llm
from llm.utils import dicts_to_table_string, remove_dict_none_values, logging_client
from llm.utils import (
dicts_to_table_string,
remove_dict_none_values,
logging_client,
simplify_usage_dict,
)
import click
import datetime
import httpx
@@ -391,6 +396,16 @@ def build_messages(self, prompt, conversation):
messages.append({"role": "user", "content": attachment_message})
return messages

def set_usage(self, response, usage):
if not usage:
return
input_tokens = usage.pop("prompt_tokens")
output_tokens = usage.pop("completion_tokens")
usage.pop("total_tokens")
response.set_usage(
input=input_tokens, output=output_tokens, details=simplify_usage_dict(usage)
)

def get_client(self, async_=False):
kwargs = {}
if self.api_base:
@@ -445,6 +460,7 @@ def execute(self, prompt, stream, response, conversation=None):
messages = self.build_messages(prompt, conversation)
kwargs = self.build_kwargs(prompt, stream)
client = self.get_client()
usage = None
if stream:
completion = client.chat.completions.create(
model=self.model_name or self.model_id,
@@ -455,6 +471,8 @@
chunks = []
for chunk in completion:
chunks.append(chunk)
if chunk.usage:
usage = chunk.usage.model_dump()
try:
content = chunk.choices[0].delta.content
except IndexError:
@@ -469,8 +487,10 @@
stream=False,
**kwargs,
)
usage = completion.usage.model_dump()
response.response_json = remove_dict_none_values(completion.model_dump())
yield completion.choices[0].message.content
self.set_usage(response, usage)
response._prompt_json = redact_data({"messages": messages})


Expand All @@ -493,6 +513,7 @@ async def execute(
messages = self.build_messages(prompt, conversation)
kwargs = self.build_kwargs(prompt, stream)
client = self.get_client(async_=True)
usage = None
if stream:
completion = await client.chat.completions.create(
model=self.model_name or self.model_id,
@@ -502,6 +523,8 @@
)
chunks = []
async for chunk in completion:
if chunk.usage:
usage = chunk.usage.model_dump()
chunks.append(chunk)
try:
content = chunk.choices[0].delta.content
@@ -518,7 +541,9 @@
**kwargs,
)
response.response_json = remove_dict_none_values(completion.model_dump())
usage = completion.usage.model_dump()
yield completion.choices[0].message.content
self.set_usage(response, usage)
response._prompt_json = redact_data({"messages": messages})


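To make the mapping concrete: a sketch of a typical OpenAI usage payload and what the `set_usage()` helper above extracts from it (the values are illustrative, not from a real response):

```python
usage = {
    "prompt_tokens": 15,
    "completion_tokens": 340,
    "total_tokens": 355,
    "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0},
}

input_tokens = usage.pop("prompt_tokens")  # 15
output_tokens = usage.pop("completion_tokens")  # 340
usage.pop("total_tokens")  # discarded - always input + output
# simplify_usage_dict() strips zero values and empty dicts, so the
# leftover details here reduce to {} and nothing extra is logged
```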
7 changes: 7 additions & 0 deletions llm/migrations.py
@@ -227,3 +227,10 @@ def m012_attachments_tables(db):
),
pk=("response_id", "attachment_id"),
)


@migration
def m013_usage(db):
db["responses"].add_column("input_tokens", int)
db["responses"].add_column("output_tokens", int)
db["responses"].add_column("token_details", str)
31 changes: 29 additions & 2 deletions llm/models.py
@@ -18,7 +18,7 @@
Set,
Union,
)
from .utils import mimetype_from_path, mimetype_from_string
from .utils import mimetype_from_path, mimetype_from_string, token_usage_string
from abc import ABC, abstractmethod
import json
from pydantic import BaseModel
@@ -208,6 +208,20 @@ def __init__(
self._start: Optional[float] = None
self._end: Optional[float] = None
self._start_utcnow: Optional[datetime.datetime] = None
self.input_tokens: Optional[int] = None
self.output_tokens: Optional[int] = None
self.token_details: Optional[dict] = None

def set_usage(
self,
*,
input: Optional[int] = None,
output: Optional[int] = None,
details: Optional[dict] = None,
):
self.input_tokens = input
self.output_tokens = output
self.token_details = details

@classmethod
def from_row(cls, db, row):
@@ -246,6 +260,11 @@ def from_row(cls, db, row):
]
return response

def token_usage(self) -> str:
return token_usage_string(
self.input_tokens, self.output_tokens, self.token_details
)

def log_to_db(self, db):
conversation = self.conversation
if not conversation:
@@ -272,11 +291,16 @@ def log_to_db(self, db):
for key, value in dict(self.prompt.options).items()
if value is not None
},
"response": self.text(),
"response": self.text_or_raise(),
"response_json": self.json(),
"conversation_id": conversation.id,
"duration_ms": self.duration_ms(),
"datetime_utc": self.datetime_utc(),
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"token_details": (
json.dumps(self.token_details) if self.token_details else None
),
}
db["responses"].insert(response)
# Persist any attachments - loop through with index
@@ -439,6 +463,9 @@ async def to_sync_response(self) -> Response:
response._end = self._end
response._start = self._start
response._start_utcnow = self._start_utcnow
response.input_tokens = self.input_tokens
response.output_tokens = self.output_tokens
response.token_details = self.token_details
return response

@classmethod
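These attributes are also visible from the Python API. A minimal sketch - the model name is illustrative, and the counts are only populated once the response has been fully consumed:

```python
import llm

model = llm.get_model("gpt-4o-mini")
response = model.prompt("Say hello")
print(response.text())  # forces the prompt to execute
print(response.input_tokens, response.output_tokens)
print(response.token_usage())  # e.g. "15 input, 340 output"
```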
26 changes: 26 additions & 0 deletions llm/utils.py
@@ -127,3 +127,29 @@ def logging_client() -> httpx.Client:
transport=_LogTransport(httpx.HTTPTransport()),
event_hooks={"request": [_no_accept_encoding], "response": [_log_response]},
)


def simplify_usage_dict(d):
# Recursively remove keys with value 0 and empty dictionaries
def remove_empty_and_zero(obj):
if isinstance(obj, dict):
cleaned = {
k: remove_empty_and_zero(v)
for k, v in obj.items()
if v != 0 and v != {}
}
return {k: v for k, v in cleaned.items() if v is not None and v != {}}
return obj

return remove_empty_and_zero(d) or {}


def token_usage_string(input_tokens, output_tokens, token_details) -> str:
bits = []
if input_tokens is not None:
bits.append(f"{format(input_tokens, ',')} input")
if output_tokens is not None:
bits.append(f"{format(output_tokens, ',')} output")
if token_details:
bits.append(json.dumps(token_details))
return ", ".join(bits)
6 changes: 5 additions & 1 deletion tests/conftest.py
@@ -66,13 +66,17 @@ def enqueue(self, messages):

def execute(self, prompt, stream, response, conversation):
self.history.append((prompt, stream, response, conversation))
gathered = []
while True:
try:
messages = self._queue.pop(0)
yield from messages
for message in messages:
gathered.append(message)
yield message
break
except IndexError:
break
response.set_usage(input=len(prompt.prompt.split()), output=len(gathered))


class AsyncMockModel(llm.AsyncModel):