diff --git a/docs/help.md b/docs/help.md
index 9db540a3..ba5d3f0d 100644
--- a/docs/help.md
+++ b/docs/help.md
@@ -122,6 +122,7 @@ Options:
   --key TEXT                  API key to use
   --save TEXT                 Save prompt with this template name
   --async                     Run prompt asynchronously
+  -u, --usage                 Show token usage
   --help                      Show this message and exit.
 ```
@@ -292,6 +293,7 @@ Options:
   -m, --model TEXT            Filter by model or model alias
   -q, --query TEXT            Search for logs matching this string
   -t, --truncate              Truncate long strings in output
+  -u, --usage                 Include token usage
   -r, --response              Just output the last response
   -c, --current               Show logs from the current conversation
   --cid, --conversation TEXT  Show logs for this conversation ID
diff --git a/docs/logging.md b/docs/logging.md
index 63722e01..56c0379d 100644
--- a/docs/logging.md
+++ b/docs/logging.md
@@ -159,7 +159,10 @@ CREATE TABLE [responses] (
    [response_json] TEXT,
    [conversation_id] TEXT REFERENCES [conversations]([id]),
    [duration_ms] INTEGER,
-   [datetime_utc] TEXT
+   [datetime_utc] TEXT,
+   [input_tokens] INTEGER,
+   [output_tokens] INTEGER,
+   [token_details] TEXT
 );
 CREATE VIRTUAL TABLE [responses_fts] USING FTS5 (
     [prompt],
diff --git a/docs/plugins/advanced-model-plugins.md b/docs/plugins/advanced-model-plugins.md
index f0efcfd1..9342d355 100644
--- a/docs/plugins/advanced-model-plugins.md
+++ b/docs/plugins/advanced-model-plugins.md
@@ -167,3 +167,32 @@ for prev_response in conversation.responses:
 
 The `response.text_or_raise()` method used there will return the text from the response or raise a `ValueError` exception if the response is an `AsyncResponse` instance that has not yet been fully resolved.
 
 This is a slightly weird hack to work around the common need to share logic for building up the `messages` list across both sync and async models.
+
+(advanced-model-plugins-usage)=
+
+## Tracking token usage
+
+Models that charge by the token should track the number of tokens used by each prompt. The `response.set_usage()` method records the number of tokens used by a response - these are then made available through the Python API and logged to the SQLite database for command-line users.
+
+`response` here is the response object that is passed to `.execute()` as an argument.
+
+Call `response.set_usage()` at the end of your `.execute()` method. It accepts keyword arguments `input=`, `output=` and `details=` - all three are optional. `input` and `output` should be integers; `details` should be a dictionary providing additional information beyond the input and output token counts.
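+
+As a sketch, a hypothetical `.execute()` method might record usage like this - here `my_client` and the attributes on its return value are illustrative stand-ins, not part of LLM:
+
+```python
+def execute(self, prompt, stream, response, conversation=None):
+    # my_client stands in for whatever API client the plugin wraps:
+    api_response = my_client.complete(prompt.prompt)
+    yield api_response.text
+    # Record usage before the generator finishes:
+    response.set_usage(
+        input=api_response.input_tokens, output=api_response.output_tokens
+    )
+```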
+
+This example logs 15 input tokens, 340 output tokens and notes that 37 tokens were cached:
+
+```python
+response.set_usage(input=15, output=340, details={"cached": 37})
+```
diff --git a/llm/cli.py b/llm/cli.py
index c75e0e3e..e0c8e47c 100644
--- a/llm/cli.py
+++ b/llm/cli.py
@@ -33,7 +33,7 @@
 from .migrations import migrate
 from .plugins import pm, load_plugins
-from .utils import mimetype_from_path, mimetype_from_string
+from .utils import mimetype_from_path, mimetype_from_string, token_usage_string
 import base64
 import httpx
 import pathlib
@@ -203,6 +203,7 @@ def cli():
 @click.option("--key", help="API key to use")
 @click.option("--save", help="Save prompt with this template name")
 @click.option("async_", "--async", is_flag=True, help="Run prompt asynchronously")
+@click.option("-u", "--usage", is_flag=True, help="Show token usage")
 def prompt(
     prompt,
     system,
@@ -220,6 +221,7 @@ def prompt(
     key,
     save,
     async_,
+    usage,
 ):
     """
     Execute a prompt
@@ -426,14 +428,24 @@ async def inner():
     except Exception as ex:
         raise click.ClickException(str(ex))
 
+    if isinstance(response, AsyncResponse):
+        response = asyncio.run(response.to_sync_response())
+
+    if usage:
+        # Show token usage to stderr in yellow
+        click.echo(
+            click.style(
+                "Token usage: {}".format(response.token_usage()), fg="yellow", bold=True
+            ),
+            err=True,
+        )
+
     # Log to the database
     if (logs_on() or log) and not no_log:
         log_path = logs_db_path()
         (log_path.parent).mkdir(parents=True, exist_ok=True)
         db = sqlite_utils.Database(log_path)
         migrate(db)
-        if isinstance(response, AsyncResponse):
-            response = asyncio.run(response.to_sync_response())
         response.log_to_db(db)
 
@@ -754,6 +766,9 @@ def logs_turn_off():
     responses.conversation_id,
     responses.duration_ms,
     responses.datetime_utc,
+    responses.input_tokens,
+    responses.output_tokens,
+    responses.token_details,
     conversations.name as conversation_name,
     conversations.model as conversation_model"""
 
@@ -809,6 +824,7 @@ def logs_turn_off():
 @click.option("-m", "--model", help="Filter by model or model alias")
 @click.option("-q", "--query", help="Search for logs matching this string")
 @click.option("-t", "--truncate", is_flag=True, help="Truncate long strings in output")
+@click.option("-u", "--usage", is_flag=True, help="Include token usage")
 @click.option("-r", "--response", is_flag=True, help="Just output the last response")
 @click.option(
     "current_conversation",
@@ -836,6 +852,7 @@ def logs_list(
     model,
     query,
     truncate,
+    usage,
     response,
     current_conversation,
     conversation_id,
@@ -998,6 +1015,14 @@ def logs_list(
             )
             click.echo("\n## Response:\n\n{}\n".format(row["response"]))
+            if usage:
+                token_usage = token_usage_string(
+                    row["input_tokens"],
+                    row["output_tokens"],
+                    json.loads(row["token_details"]) if row["token_details"] else None,
+                )
+                if token_usage:
+                    click.echo("## Token usage:\n\n{}\n".format(token_usage))
 
 
 @cli.group(
diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py
index 6234d5b1..ab33d1b4 100644
--- a/llm/default_plugins/openai_models.py
+++ b/llm/default_plugins/openai_models.py
@@ -1,6 +1,11 @@
 from llm import AsyncModel, EmbeddingModel, Model, hookimpl
 import llm
-from llm.utils import dicts_to_table_string, remove_dict_none_values, logging_client
+from llm.utils import (
+    dicts_to_table_string,
+    remove_dict_none_values,
+    logging_client,
+    simplify_usage_dict,
+)
 import click
 import datetime
 import httpx
@@ -391,6 +396,16 @@ def build_messages(self, prompt, conversation):
                 messages.append({"role": "user", "content": attachment_message})
         return messages
 
+    def set_usage(self, response, usage):
+        if not usage:
+            return
+        input_tokens = usage.pop("prompt_tokens")
+        output_tokens = usage.pop("completion_tokens")
+        usage.pop("total_tokens")
+        response.set_usage(
+            input=input_tokens, output=output_tokens, details=simplify_usage_dict(usage)
+        )
+
     def get_client(self, async_=False):
         kwargs = {}
         if self.api_base:
@@ -445,6 +460,7 @@ def execute(self, prompt, stream, response, conversation=None):
         messages = self.build_messages(prompt, conversation)
         kwargs = self.build_kwargs(prompt, stream)
         client = self.get_client()
+        usage = None
         if stream:
             completion = client.chat.completions.create(
                 model=self.model_name or self.model_id,
@@ -455,6 +471,8 @@ def execute(self, prompt, stream, response, conversation=None):
             chunks = []
             for chunk in completion:
                 chunks.append(chunk)
+                if chunk.usage:
+                    usage = chunk.usage.model_dump()
                 try:
                     content = chunk.choices[0].delta.content
                 except IndexError:
@@ -469,8 +487,10 @@ def execute(self, prompt, stream, response, conversation=None):
                 stream=False,
                 **kwargs,
             )
+            usage = completion.usage.model_dump()
             response.response_json = remove_dict_none_values(completion.model_dump())
             yield completion.choices[0].message.content
+        self.set_usage(response, usage)
         response._prompt_json = redact_data({"messages": messages})
 
@@ -493,6 +513,7 @@ async def execute(
         messages = self.build_messages(prompt, conversation)
         kwargs = self.build_kwargs(prompt, stream)
         client = self.get_client(async_=True)
+        usage = None
         if stream:
             completion = await client.chat.completions.create(
                 model=self.model_name or self.model_id,
@@ -502,6 +523,8 @@ async def execute(
             )
             chunks = []
             async for chunk in completion:
+                if chunk.usage:
+                    usage = chunk.usage.model_dump()
                 chunks.append(chunk)
                 try:
                     content = chunk.choices[0].delta.content
@@ -518,7 +541,9 @@ async def execute(
                 **kwargs,
             )
             response.response_json = remove_dict_none_values(completion.model_dump())
+            usage = completion.usage.model_dump()
             yield completion.choices[0].message.content
+        self.set_usage(response, usage)
         response._prompt_json = redact_data({"messages": messages})
diff --git a/llm/migrations.py b/llm/migrations.py
index 91da6429..b8ac8b13 100644
--- a/llm/migrations.py
+++ b/llm/migrations.py
@@ -227,3 +227,10 @@ def m012_attachments_tables(db):
         ),
         pk=("response_id", "attachment_id"),
     )
+
+
+@migration
+def m013_usage(db):
+    db["responses"].add_column("input_tokens", int)
+    db["responses"].add_column("output_tokens", int)
+    db["responses"].add_column("token_details", str)
diff --git a/llm/models.py b/llm/models.py
index c160798b..5bf9f11c 100644
--- a/llm/models.py
+++ b/llm/models.py
@@ -18,7 +18,7 @@
     Set,
     Union,
 )
-from .utils import mimetype_from_path, mimetype_from_string
+from .utils import mimetype_from_path, mimetype_from_string, token_usage_string
 from abc import ABC, abstractmethod
 import json
 from pydantic import BaseModel
@@ -208,6 +208,20 @@ def __init__(
         self._start: Optional[float] = None
         self._end: Optional[float] = None
         self._start_utcnow: Optional[datetime.datetime] = None
+        self.input_tokens: Optional[int] = None
+        self.output_tokens: Optional[int] = None
+        self.token_details: Optional[dict] = None
+
+    def set_usage(
+        self,
+        *,
+        input: Optional[int] = None,
+        output: Optional[int] = None,
+        details: Optional[dict] = None,
+    ):
+        self.input_tokens = input
+        self.output_tokens = output
+        self.token_details = details
 
     @classmethod
     def from_row(cls, db, row):
@@ -246,6 +260,11 @@ def from_row(cls, db, row):
         ]
         return response
 
+
+    def token_usage(self) -> str:
+        return token_usage_string(
+            self.input_tokens, self.output_tokens, self.token_details
+        )
+
     def log_to_db(self, db):
         conversation = self.conversation
         if not conversation:
@@ -272,11 +291,16 @@ def log_to_db(self, db):
                 for key, value in dict(self.prompt.options).items()
                 if value is not None
             },
-            "response": self.text(),
+            "response": self.text_or_raise(),
             "response_json": self.json(),
             "conversation_id": conversation.id,
             "duration_ms": self.duration_ms(),
             "datetime_utc": self.datetime_utc(),
+            "input_tokens": self.input_tokens,
+            "output_tokens": self.output_tokens,
+            "token_details": (
+                json.dumps(self.token_details) if self.token_details else None
+            ),
         }
         db["responses"].insert(response)
         # Persist any attachments - loop through with index
@@ -439,6 +463,9 @@ async def to_sync_response(self) -> Response:
         response._end = self._end
         response._start = self._start
         response._start_utcnow = self._start_utcnow
+        response.input_tokens = self.input_tokens
+        response.output_tokens = self.output_tokens
+        response.token_details = self.token_details
         return response
 
     @classmethod
diff --git a/llm/utils.py b/llm/utils.py
index d2618dd4..e9853185 100644
--- a/llm/utils.py
+++ b/llm/utils.py
@@ -127,3 +127,29 @@ def logging_client() -> httpx.Client:
         transport=_LogTransport(httpx.HTTPTransport()),
         event_hooks={"request": [_no_accept_encoding], "response": [_log_response]},
     )
+
+
+def simplify_usage_dict(d):
+    # Recursively remove keys with value 0 and empty dictionaries
+    def remove_empty_and_zero(obj):
+        if isinstance(obj, dict):
+            cleaned = {
+                k: remove_empty_and_zero(v)
+                for k, v in obj.items()
+                if v != 0 and v != {}
+            }
+            return {k: v for k, v in cleaned.items() if v is not None and v != {}}
+        return obj
+
+    return remove_empty_and_zero(d) or {}
+
+
+def token_usage_string(input_tokens, output_tokens, token_details) -> str:
+    bits = []
+    if input_tokens is not None:
+        bits.append(f"{input_tokens:,} input")
+    if output_tokens is not None:
+        bits.append(f"{output_tokens:,} output")
+    if token_details:
+        bits.append(json.dumps(token_details))
+    return ", ".join(bits)
diff --git a/tests/conftest.py b/tests/conftest.py
index 6fb8bf75..447e1caa 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -66,13 +66,17 @@ def enqueue(self, messages):
 
     def execute(self, prompt, stream, response, conversation):
         self.history.append((prompt, stream, response, conversation))
+        gathered = []
         while True:
             try:
                 messages = self._queue.pop(0)
-                yield from messages
+                for message in messages:
+                    gathered.append(message)
+                    yield message
                 break
             except IndexError:
                 break
+        response.set_usage(input=len(prompt.prompt.split()), output=len(gathered))
 
 
 class AsyncMockModel(llm.AsyncModel):
diff --git a/tests/test_chat.py b/tests/test_chat.py
index 285fa476..1a31f290 100644
--- a/tests/test_chat.py
+++ b/tests/test_chat.py
@@ -62,6 +62,9 @@ def test_chat_basic(mock_model, logs_db):
             "conversation_id": conversation_id,
             "duration_ms": ANY,
             "datetime_utc": ANY,
+            "input_tokens": 1,
+            "output_tokens": 1,
+            "token_details": None,
         },
         {
             "id": ANY,
@@ -75,6 +78,9 @@ def test_chat_basic(mock_model, logs_db):
             "conversation_id": conversation_id,
             "duration_ms": ANY,
             "datetime_utc": ANY,
+            "input_tokens": 2,
+            "output_tokens": 1,
+            "token_details": None,
         },
     ]
     # Now continue that conversation
@@ -116,6 +122,9 @@ def test_chat_basic(mock_model, logs_db):
             "conversation_id": conversation_id,
             "duration_ms": ANY,
             "datetime_utc": ANY,
+            "input_tokens": 1,
+            "output_tokens": 1,
+            "token_details": None,
         }
     ]
 
@@ -153,6 +162,9 @@ def test_chat_system(mock_model, logs_db):
             "conversation_id": ANY,
             "duration_ms": ANY,
             "datetime_utc": ANY,
+            "input_tokens": 1,
+            "output_tokens": 1,
+            "token_details": None,
         }
     ]
 
@@ -181,6 +193,9 @@ def test_chat_options(mock_model, logs_db):
             "conversation_id": ANY,
             "duration_ms": ANY,
             "datetime_utc": ANY,
+            "input_tokens": 1,
+            "output_tokens": 1,
+            "token_details": None,
         }
     ]
 
diff --git a/tests/test_cli_openai_models.py b/tests/test_cli_openai_models.py
index b65ad078..3d0a7c16 100644
--- a/tests/test_cli_openai_models.py
+++ b/tests/test_cli_openai_models.py
@@ -147,7 +147,8 @@ def test_only_gpt4_audio_preview_allows_mp3_or_wav(httpx_mock, model, filetype):
 
 
 @pytest.mark.parametrize("async_", (False, True))
-def test_gpt4o_mini_sync_and_async(monkeypatch, tmpdir, httpx_mock, async_):
+@pytest.mark.parametrize("usage", (None, "-u", "--usage"))
+def test_gpt4o_mini_sync_and_async(monkeypatch, tmpdir, httpx_mock, async_, usage):
     user_path = tmpdir / "user_dir"
     log_db = user_path / "logs.db"
     monkeypatch.setenv("LLM_USER_PATH", str(user_path))
@@ -173,21 +174,25 @@ def test_gpt4o_mini_sync_and_async(monkeypatch, tmpdir, httpx_mock, async_):
                 }
             ],
             "usage": {
-                "prompt_tokens": 10,
-                "completion_tokens": 2,
-                "total_tokens": 12,
+                "prompt_tokens": 1000,
+                "completion_tokens": 2000,
+                "total_tokens": 3000,
             },
             "system_fingerprint": "fp_49254d0e9b",
         },
         headers={"Content-Type": "application/json"},
    )
-    runner = CliRunner()
+    runner = CliRunner(mix_stderr=False)
     args = ["-m", "gpt-4o-mini", "--key", "x", "--no-stream"]
+    if usage:
+        args.append(usage)
     if async_:
         args.append("--async")
     result = runner.invoke(cli, args, catch_exceptions=False)
     assert result.exit_code == 0
     assert result.output == "Ho ho ho\n"
+    if usage:
+        assert result.stderr == "Token usage: 1,000 input, 2,000 output\n"
     # Confirm it was correctly logged
     assert log_db.exists()
     db = sqlite_utils.Database(str(log_db))
diff --git a/tests/test_llm.py b/tests/test_llm.py
index 0e54cc91..b83ff842 100644
--- a/tests/test_llm.py
+++ b/tests/test_llm.py
@@ -37,6 +37,8 @@ def log_path(user_path):
             "model": "davinci",
             "datetime_utc": (start + datetime.timedelta(seconds=i)).isoformat(),
             "conversation_id": "abc123",
+            "input_tokens": 2,
+            "output_tokens": 5,
         }
         for i in range(100)
     )
@@ -46,9 +48,12 @@ def log_path(user_path):
 datetime_re = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
 
 
-def test_logs_text(log_path):
+@pytest.mark.parametrize("usage", (False, True))
+def test_logs_text(log_path, usage):
     runner = CliRunner()
     args = ["logs", "-p", str(log_path)]
+    if usage:
+        args.append("-u")
     result = runner.invoke(cli, args, catch_exceptions=False)
     assert result.exit_code == 0
     output = result.output
@@ -64,18 +69,24 @@ def test_logs_text(log_path):
         "system\n\n"
         "## Response:\n\n"
         "response\n\n"
+    ) + ("## Token usage:\n\n2 input, 5 output\n\n" if usage else "") + (
         "# YYYY-MM-DDTHH:MM:SS conversation: abc123\n\n"
         "Model: **davinci**\n\n"
         "## Prompt:\n\n"
         "prompt\n\n"
         "## Response:\n\n"
         "response\n\n"
+    ) + (
+        "## Token usage:\n\n2 input, 5 output\n\n" if usage else ""
+    ) + (
         "# YYYY-MM-DDTHH:MM:SS conversation: abc123\n\n"
         "Model: **davinci**\n\n"
         "## Prompt:\n\n"
         "prompt\n\n"
         "## Response:\n\n"
         "response\n\n"
+    ) + (
+        "## Token usage:\n\n2 input, 5 output\n\n" if usage else ""
     )
diff --git a/tests/test_migrate.py b/tests/test_migrate.py
index 1c68de93..d1da5571 100644
--- a/tests/test_migrate.py
+++ b/tests/test_migrate.py
@@ -17,6 +17,9 @@
     "conversation_id": str,
     "duration_ms": int,
     "datetime_utc": str,
+    "input_tokens": int,
+    "output_tokens": int,
"token_details": str, } diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..85ed54ae --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,42 @@ +import pytest +from llm.utils import simplify_usage_dict + + +@pytest.mark.parametrize( + "input_data,expected_output", + [ + ( + { + "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0}, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 1, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0, + }, + }, + {"completion_tokens_details": {"audio_tokens": 1}}, + ), + ( + { + "details": {"tokens": 5, "audio_tokens": 2}, + "more_details": {"accepted_tokens": 3}, + }, + { + "details": {"tokens": 5, "audio_tokens": 2}, + "more_details": {"accepted_tokens": 3}, + }, + ), + ({"details": {"tokens": 0, "audio_tokens": 0}, "more_details": {}}, {}), + ({"level1": {"level2": {"value": 0, "another_value": {}}}}, {}), + ( + { + "level1": {"level2": {"value": 0, "another_value": 1}}, + "level3": {"empty_dict": {}, "valid_token": 10}, + }, + {"level1": {"level2": {"another_value": 1}}, "level3": {"valid_token": 10}}, + ), + ], +) +def test_simplify_usage_dict(input_data, expected_output): + assert simplify_usage_dict(input_data) == expected_output