Skip to content

Commit

Permalink
Use response.set_usage(), closes #15
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Nov 20, 2024
1 parent f590da3 commit 8da8e35
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 5 deletions.
34 changes: 30 additions & 4 deletions llm_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,12 @@ def build_body(self, prompt, messages):
body["random_seed"] = prompt.options.random_seed
return body

def set_usage(self, response, usage):
    """Record token counts from a Mistral API ``usage`` dict onto *response*.

    Maps the API's ``prompt_tokens``/``completion_tokens`` fields to the
    ``input``/``output`` arguments expected by ``response.set_usage()``.
    Raises ``KeyError`` if either field is missing from *usage*.
    """
    prompt_tokens = usage["prompt_tokens"]
    completion_tokens = usage["completion_tokens"]
    response.set_usage(input=prompt_tokens, output=completion_tokens)


class Mistral(_Shared, llm.Model):
def execute(self, prompt, stream, response, conversation):
Expand Down Expand Up @@ -281,13 +287,19 @@ def execute(self, prompt, stream, response, conversation):
raise click.ClickException(
f"{event_source.response.status_code}: {type} - {message}"
)
usage = None
event_source.response.raise_for_status()
for sse in event_source.iter_sse():
if sse.data != "[DONE]":
try:
yield sse.json()["choices"][0]["delta"]["content"]
event = sse.json()
if "usage" in event:
usage = event["usage"]
yield event["choices"][0]["delta"]["content"]
except KeyError:
pass
if usage:
self.set_usage(response, usage)
else:
with httpx.Client() as client:
api_response = client.post(
Expand All @@ -302,7 +314,11 @@ def execute(self, prompt, stream, response, conversation):
)
api_response.raise_for_status()
yield api_response.json()["choices"][0]["message"]["content"]
response.response_json = api_response.json()
details = api_response.json()
usage = details.pop("usage", None)
response.response_json = details
if usage:
self.set_usage(response, usage)


class AsyncMistral(_Shared, llm.AsyncModel):
Expand Down Expand Up @@ -345,12 +361,18 @@ async def execute(self, prompt, stream, response, conversation):
f"{event_source.response.status_code}: {type} - {message}"
)
event_source.response.raise_for_status()
usage = None
async for sse in event_source.aiter_sse():
if sse.data != "[DONE]":
try:
yield sse.json()["choices"][0]["delta"]["content"]
event = sse.json()
if "usage" in event:
usage = event["usage"]
yield event["choices"][0]["delta"]["content"]
except KeyError:
pass
if usage:
self.set_usage(response, usage)
else:
async with httpx.AsyncClient() as client:
api_response = await client.post(
Expand All @@ -365,7 +387,11 @@ async def execute(self, prompt, stream, response, conversation):
)
api_response.raise_for_status()
yield api_response.json()["choices"][0]["message"]["content"]
response.response_json = api_response.json()
details = api_response.json()
usage = details.pop("usage", None)
response.response_json = details
if usage:
self.set_usage(response, usage)


class MistralEmbed(llm.EmbeddingModel):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ classifiers = [
"License :: OSI Approved :: Apache Software License"
]
dependencies = [
"llm>=0.18",
"llm>=0.19a0",
"httpx",
"httpx-sse",
]
Expand Down

0 comments on commit 8da8e35

Please sign in to comment.