From 660bea934f4b3e9ab773041edd628e09045d8de4 Mon Sep 17 00:00:00 2001
From: Piyush Jain
Date: Tue, 15 Oct 2024 13:01:52 -0700
Subject: [PATCH] Fixes token logging in callbacks when streaming=True is used.

---
 libs/aws/langchain_aws/chat_models/bedrock.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py
index 23a842a2..54d34ffb 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock.py
@@ -477,14 +477,19 @@ def _stream(
                 **kwargs,
             ):
                 if isinstance(chunk, AIMessageChunk):
-                    yield ChatGenerationChunk(message=chunk)
+                    generation_chunk = ChatGenerationChunk(message=chunk)
+                    if run_manager:
+                        run_manager.on_llm_new_token(
+                            generation_chunk.text, chunk=generation_chunk
+                        )
+                    yield generation_chunk
                 else:
                     delta = chunk.text
                     if generation_info := chunk.generation_info:
                         usage_metadata = generation_info.pop("usage_metadata", None)
                     else:
                         usage_metadata = None
-                    yield ChatGenerationChunk(
+                    generation_chunk = ChatGenerationChunk(
                         message=AIMessageChunk(
                             content=delta,
                             response_metadata=chunk.generation_info,
@@ -493,6 +498,11 @@
                         if chunk.generation_info is not None
                         else AIMessageChunk(content=delta)
                     )
+                    if run_manager:
+                        run_manager.on_llm_new_token(
+                            generation_chunk.text, chunk=generation_chunk
+                        )
+                    yield generation_chunk
 
     def _generate(
         self,
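
Usage sketch (illustrative, not part of the patch): with this change, a
callback handler attached to ChatBedrock should receive each streamed chunk
through on_llm_new_token when streaming=True. The model_id and the TokenLogger
handler below are assumptions made for the example; AWS credentials and region
are taken from the ambient environment.

    from langchain_aws import ChatBedrock
    from langchain_core.callbacks import BaseCallbackHandler

    class TokenLogger(BaseCallbackHandler):
        def on_llm_new_token(self, token: str, **kwargs) -> None:
            # Fires once per streamed chunk. Before this patch, _stream
            # never called run_manager.on_llm_new_token, so streaming
            # callbacks received no tokens.
            print(f"token: {token!r}")

    llm = ChatBedrock(
        model_id="anthropic.claude-3-sonnet-20240229-v1:0",  # assumed model
        streaming=True,
        callbacks=[TokenLogger()],
    )
    llm.invoke("Say hello in one word.")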