diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py
index 701869a40..7a8a7b01e 100644
--- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py
+++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py
@@ -55,6 +55,7 @@ def __init__(
         base_url: str | None = None,
         user: str | None = None,
         client: openai.AsyncClient | None = None,
+        token_usage_callback: Callable[[Dict[str, Any]], None] | None = None,
         temperature: float | None = None,
     ) -> None:
         """
@@ -411,11 +412,12 @@ def chat(
             n=n,
             temperature=temperature,
             stream=True,
+            stream_options={"include_usage": self.token_usage_callback is not None},
             user=user,
             **opts,
         )
 
-        return LLMStream(oai_stream=cmp, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx)
+        return LLMStream(oai_stream=cmp, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, token_usage_callback=self.token_usage_callback)
 
 
 class LLMStream(llm.LLMStream):
@@ -425,8 +427,10 @@ def __init__(
         oai_stream: Awaitable[openai.AsyncStream[ChatCompletionChunk]],
         chat_ctx: llm.ChatContext,
         fnc_ctx: llm.FunctionContext | None,
+        token_usage_callback: Callable[[Dict[str, Any]], None] | None = None,
     ) -> None:
         super().__init__(chat_ctx=chat_ctx, fnc_ctx=fnc_ctx)
+        self.token_usage_callback = token_usage_callback
         self._awaitable_oai_stream = oai_stream
         self._oai_stream: openai.AsyncStream[ChatCompletionChunk] | None = None
 
@@ -446,6 +450,9 @@ async def __anext__(self):
             self._oai_stream = await self._awaitable_oai_stream
 
         async for chunk in self._oai_stream:
+            if chunk.usage and self.token_usage_callback:
+                self.token_usage_callback(chunk.usage)
+
             for choice in chunk.choices:
                 chat_chunk = self._parse_choice(choice)
                 if chat_chunk is not None:
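
To illustrate how a caller might use the new parameter, here is a minimal sketch. It assumes the plugin class is `livekit.plugins.openai.LLM` and that it accepts a `model` argument (neither appears in the diff hunks above). Note that the annotation in the diff is `Dict[str, Any]`, but the value passed to the callback is `chunk.usage`, which the OpenAI SDK returns as a `CompletionUsage` object on the final streamed chunk when `stream_options={"include_usage": True}` is set, so the sketch reads attributes rather than dict keys. The diff also relies on `Callable`, `Dict`, and `Any` being imported in `llm.py`.

```python
from livekit.plugins import openai as openai_plugin


def log_token_usage(usage) -> None:
    # `usage` is expected to be an openai.types.CompletionUsage on the final chunk
    # of the stream; log the token counts for cost tracking.
    print(
        f"prompt={usage.prompt_tokens} "
        f"completion={usage.completion_tokens} "
        f"total={usage.total_tokens}"
    )


# Hypothetical construction: pass the callback added by this diff so that
# usage is reported once per completed chat stream.
llm = openai_plugin.LLM(
    model="gpt-4o-mini",
    token_usage_callback=log_token_usage,
)
```

With this wiring, `chat()` requests usage data from the API only when a callback is configured, and `LLMStream.__anext__` invokes the callback once when the usage-bearing chunk arrives.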