Skip to content

Commit

Permalink
Getting llm_output (containing usage and stop_reason from both st…
Browse files Browse the repository at this point in the history
…reaming and non-streaming generation and passing into the ChatResult
  • Loading branch information
NAPTlME committed Apr 22, 2024
1 parent 8006cb5 commit 1eeb60f
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions libs/aws/langchain_aws/chat_models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Extra

from langchain_aws.llms.bedrock import BedrockBase
from langchain_aws.llms.bedrock import BedrockBase, _combine_generation_info_for_llm_result
from langchain_aws.utils import (
get_num_tokens_anthropic,
get_token_ids_anthropic,
Expand Down Expand Up @@ -331,11 +331,14 @@ def _generate(
**kwargs: Any,
) -> ChatResult:
completion = ""
llm_output: Dict[str, Any] = {"model_id": self.model_id}
usage_info: Dict[str, Any] = {}
llm_output: Dict[str, Any] = {}
provider_stop_reason_code = self.provider_stop_reason_key_map.get(self._get_provider(), "stop_reason")
if self.streaming:
response_metadata: list[Dict[str, Any]] = []
for chunk in self._stream(messages, stop, run_manager, **kwargs):
completion += chunk.text
response_metadata.append(chunk.message.response_metadata)
llm_output = _combine_generation_info_for_llm_result(response_metadata, provider_stop_reason_code)
else:
provider = self._get_provider()
prompt, system, formatted_messages = None, None, None
Expand All @@ -353,7 +356,7 @@ def _generate(
if stop:
params["stop_sequences"] = stop

completion, usage_info = self._prepare_input_and_invoke(
completion, llm_output = self._prepare_input_and_invoke(
prompt=prompt,
stop=stop,
run_manager=run_manager,
Expand All @@ -362,13 +365,12 @@ def _generate(
**params,
)

llm_output["usage"] = usage_info

llm_output["model_id"] = self.model_id
return ChatResult(
generations=[
ChatGeneration(
message=AIMessage(
content=completion, additional_kwargs={"usage": usage_info}
content=completion, additional_kwargs=llm_output
)
)
],
Expand Down

0 comments on commit 1eeb60f

Please sign in to comment.