diff --git a/llm_mistral.py b/llm_mistral.py
index 9e8f41b..0a64b7a 100644
--- a/llm_mistral.py
+++ b/llm_mistral.py
@@ -241,6 +241,12 @@ def build_body(self, prompt, messages):
             body["random_seed"] = prompt.options.random_seed
         return body
 
+    def set_usage(self, response, usage):
+        response.set_usage(
+            input=usage["prompt_tokens"],
+            output=usage["completion_tokens"],
+        )
+
 
 class Mistral(_Shared, llm.Model):
     def execute(self, prompt, stream, response, conversation):
@@ -281,13 +287,19 @@ def execute(self, prompt, stream, response, conversation):
                         raise click.ClickException(
                             f"{event_source.response.status_code}: {type} - {message}"
                         )
+                    usage = None
                     event_source.response.raise_for_status()
                     for sse in event_source.iter_sse():
                         if sse.data != "[DONE]":
                             try:
-                                yield sse.json()["choices"][0]["delta"]["content"]
+                                event = sse.json()
+                                if "usage" in event:
+                                    usage = event["usage"]
+                                yield event["choices"][0]["delta"]["content"]
                             except KeyError:
                                 pass
+                    if usage:
+                        self.set_usage(response, usage)
         else:
             with httpx.Client() as client:
                 api_response = client.post(
@@ -302,7 +314,11 @@
                 )
                 api_response.raise_for_status()
                 yield api_response.json()["choices"][0]["message"]["content"]
-                response.response_json = api_response.json()
+                details = api_response.json()
+                usage = details.pop("usage", None)
+                response.response_json = details
+                if usage:
+                    self.set_usage(response, usage)
 
 
 class AsyncMistral(_Shared, llm.AsyncModel):
@@ -345,12 +361,18 @@ async def execute(self, prompt, stream, response, conversation):
                             f"{event_source.response.status_code}: {type} - {message}"
                         )
                     event_source.response.raise_for_status()
+                    usage = None
                     async for sse in event_source.aiter_sse():
                         if sse.data != "[DONE]":
                             try:
-                                yield sse.json()["choices"][0]["delta"]["content"]
+                                event = sse.json()
+                                if "usage" in event:
+                                    usage = event["usage"]
+                                yield event["choices"][0]["delta"]["content"]
                             except KeyError:
                                 pass
+                    if usage:
+                        self.set_usage(response, usage)
         else:
             async with httpx.AsyncClient() as client:
                 api_response = await client.post(
@@ -365,7 +387,11 @@
                 )
                 api_response.raise_for_status()
                 yield api_response.json()["choices"][0]["message"]["content"]
-                response.response_json = api_response.json()
+                details = api_response.json()
+                usage = details.pop("usage", None)
+                response.response_json = details
+                if usage:
+                    self.set_usage(response, usage)
 
 
 class MistralEmbed(llm.EmbeddingModel):
diff --git a/pyproject.toml b/pyproject.toml
index 3f17390..20605cb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,7 @@ classifiers = [
     "License :: OSI Approved :: Apache Software License"
 ]
 dependencies = [
-    "llm>=0.18",
+    "llm>=0.19a0",
     "httpx",
     "httpx-sse",
 ]
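
Note: a minimal sketch (not part of the patch) of how the usage recorded by the
set_usage() hook above can be read back through llm's Python API. It assumes
llm>=0.19a0, which is what introduced Response.set_usage() and Response.usage(),
and it assumes the plugin registers a "mistral-small" model ID; the model ID and
Usage fields may differ across versions.

    # sketch.py -- illustrative only; requires llm>=0.19a0, llm-mistral,
    # and a Mistral API key configured via `llm keys set mistral`
    import llm

    model = llm.get_model("mistral-small")
    response = model.prompt("Say hello in one word")
    # Consuming the response triggers execute(), which calls set_usage()
    print(response.text())
    # Usage(input=..., output=..., details=...) -- input/output are the
    # prompt_tokens/completion_tokens values this patch passes to set_usage()
    print(response.usage())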