Skip to content

Commit

Permalink
Fix bug with Agents for the enable_trace events (#254)
Browse files Browse the repository at this point in the history
The trace_log was either always empty or only showed the last one
because the events come through in succession and the last one was being
returned only. Changed the trace log variable to a list and append new
events then at the end of of the loop serialize the array to a string.

Added new unit tests for our different output types and to check the
existence of the trace log.

Fixes bug introduced in #244.

---------

Co-authored-by: John Baker <[email protected]>
  • Loading branch information
jdbaker01 and bakjohn authored Oct 24, 2024
1 parent 95794e1 commit fd3c820
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 76 deletions.
6 changes: 4 additions & 2 deletions libs/aws/langchain_aws/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ def parse_agent_response(response: Any) -> OutputType:
response_text = ""
event_stream = response["completion"]
session_id = response["sessionId"]
trace_log = ""
trace_log_elements = []
for event in event_stream:
if "trace" in event:
trace_log = json.dumps(event["trace"])
trace_log_elements.append(event["trace"])

if "returnControl" in event:
response_text = json.dumps(event)
Expand All @@ -72,6 +72,8 @@ def parse_agent_response(response: Any) -> OutputType:
if "chunk" in event:
response_text = event["chunk"]["bytes"].decode("utf-8")

trace_log = json.dumps(trace_log_elements)

agent_finish = BedrockAgentFinish(
return_values={"output": response_text},
log=response_text,
Expand Down
65 changes: 65 additions & 0 deletions libs/aws/tests/unit_tests/agents/test_bedrock_agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import unittest
from base64 import b64encode
from typing import Union

from langchain_aws.agents.base import (
BedrockAgentAction,
BedrockAgentFinish,
parse_agent_response,
)


class TestBedrockAgentResponseParser(unittest.TestCase):
def setUp(self) -> None:
self.maxDiff = None
# Mock successful response with function invocation
self.mock_success_return_of_control_response = {
"sessionId": "123",
"completion": [
{
"returnControl": {
"invocationInputs": [
{
"functionInvocationInput": {
"actionGroup": "price_tool_action_group",
"function": "PriceTool",
"parameters": [
{"name": "Symbol", "value": "XYZ"},
{"name": "Start_Date", "value": "20241020"},
{"name": "End_Date", "value": "20241020"},
],
}
}
]
}
}
],
}

self.mock_success_finish_response = {
"sessionId": "123",
"completion": [
{"chunk": {"bytes": b64encode("FAKE DATA HERE".encode())}},
{"trace": "This is a fake trace event."},
],
}

def test_parse_return_of_control_invocation(self) -> None:
response = self.mock_success_return_of_control_response
parsed_response: Union[list[BedrockAgentAction], BedrockAgentFinish]
parsed_response = parse_agent_response(response)
self.assertIsInstance(
parsed_response, list, "Expected a list of BedrockAgentAction."
)

def test_parse_finish_invocation(self) -> None:
response = self.mock_success_finish_response
parsed_response: Union[list[BedrockAgentAction], BedrockAgentFinish]
parsed_response = parse_agent_response(response)
# Type narrowing - now TypeScript knows parsed_response is BedrockAgentFinish
assert isinstance(parsed_response, BedrockAgentFinish)
assert parsed_response.trace_log is not None, "Expected trace_log"

self.assertGreater(
len(parsed_response.trace_log), 0, "Expected a trace log, none received."
)
Loading

0 comments on commit fd3c820

Please sign in to comment.