From 08d4376a2e863a5a8760aea9a48895eb92563546 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 19 Nov 2024 18:41:12 -0800 Subject: [PATCH] llm prompt -u/--usage option Refs https://github.com/simonw/llm/issues/610#issuecomment-2487217026 --- docs/help.md | 1 + llm/cli.py | 16 ++++++++++++++-- llm/models.py | 10 ++++++++++ tests/test_cli_openai_models.py | 9 +++++++-- 4 files changed, 32 insertions(+), 4 deletions(-) diff --git a/docs/help.md b/docs/help.md index 9db540a3..34f429b8 100644 --- a/docs/help.md +++ b/docs/help.md @@ -122,6 +122,7 @@ Options: --key TEXT API key to use --save TEXT Save prompt with this template name --async Run prompt asynchronously + -u, --usage Show token usage --help Show this message and exit. ``` diff --git a/llm/cli.py b/llm/cli.py index f24b24d7..22699a92 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -203,6 +203,7 @@ def cli(): @click.option("--key", help="API key to use") @click.option("--save", help="Save prompt with this template name") @click.option("async_", "--async", is_flag=True, help="Run prompt asynchronously") +@click.option("-u", "--usage", is_flag=True, help="Show token usage") def prompt( prompt, system, @@ -220,6 +221,7 @@ def prompt( key, save, async_, + usage, ): """ Execute a prompt @@ -426,14 +428,24 @@ async def inner(): except Exception as ex: raise click.ClickException(str(ex)) + if isinstance(response, AsyncResponse): + response = asyncio.run(response.to_sync_response()) + + if usage: + # Show token usage to stderr in yellow + click.echo( + click.style( + "Token usage: {}".format(response.token_usage()), fg="yellow", bold=True + ), + err=True, + ) + # Log to the database if (logs_on() or log) and not no_log: log_path = logs_db_path() (log_path.parent).mkdir(parents=True, exist_ok=True) db = sqlite_utils.Database(log_path) migrate(db) - if isinstance(response, AsyncResponse): - response = asyncio.run(response.to_sync_response()) response.log_to_db(db) diff --git a/llm/models.py b/llm/models.py index 0d6a61f6..759f9f76 100644 --- a/llm/models.py +++ b/llm/models.py @@ -260,6 +260,16 @@ def from_row(cls, db, row): ] return response + def token_usage(self) -> str: + bits = [] + if self.input_tokens is not None: + bits.append(f"{self.input_tokens} input") + if self.output_tokens is not None: + bits.append(f"{self.output_tokens} output") + if self.token_details: + bits.append(json.dumps(self.token_details)) + return ", ".join(bits) + def log_to_db(self, db): conversation = self.conversation if not conversation: diff --git a/tests/test_cli_openai_models.py b/tests/test_cli_openai_models.py index b65ad078..8bd45338 100644 --- a/tests/test_cli_openai_models.py +++ b/tests/test_cli_openai_models.py @@ -147,7 +147,8 @@ def test_only_gpt4_audio_preview_allows_mp3_or_wav(httpx_mock, model, filetype): @pytest.mark.parametrize("async_", (False, True)) -def test_gpt4o_mini_sync_and_async(monkeypatch, tmpdir, httpx_mock, async_): +@pytest.mark.parametrize("usage", (None, "-u", "--usage")) +def test_gpt4o_mini_sync_and_async(monkeypatch, tmpdir, httpx_mock, async_, usage): user_path = tmpdir / "user_dir" log_db = user_path / "logs.db" monkeypatch.setenv("LLM_USER_PATH", str(user_path)) @@ -181,13 +182,17 @@ def test_gpt4o_mini_sync_and_async(monkeypatch, tmpdir, httpx_mock, async_): }, headers={"Content-Type": "application/json"}, ) - runner = CliRunner() + runner = CliRunner(mix_stderr=False) args = ["-m", "gpt-4o-mini", "--key", "x", "--no-stream"] + if usage: + args.append(usage) if async_: args.append("--async") result = runner.invoke(cli, args, catch_exceptions=False) assert result.exit_code == 0 assert result.output == "Ho ho ho\n" + if usage: + assert result.stderr == "Token usage: 10 input, 2 output\n" # Confirm it was correctly logged assert log_db.exists() db = sqlite_utils.Database(str(log_db))