From 211e0ecfa0be7ce1448fbeb6d77909913f6f7723 Mon Sep 17 00:00:00 2001 From: John Baker Date: Tue, 5 Nov 2024 16:37:08 -0800 Subject: [PATCH 1/7] Add the enable_human_input flag to create agent. --- libs/aws/langchain_aws/agents/base.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/libs/aws/langchain_aws/agents/base.py b/libs/aws/langchain_aws/agents/base.py index eddd49f7..e9044b8f 100644 --- a/libs/aws/langchain_aws/agents/base.py +++ b/libs/aws/langchain_aws/agents/base.py @@ -193,7 +193,7 @@ def _get_action_group_and_function_names(tool: BaseTool) -> Tuple[str, str]: def _create_bedrock_action_groups( - bedrock_client: Any, agent_id: str, tools: List[BaseTool] + bedrock_client: Any, agent_id: str, tools: List[BaseTool], enable_human_input: Optional[bool] = False ) -> None: """Create the bedrock action groups for the agent""" @@ -201,6 +201,7 @@ def _create_bedrock_action_groups( for tool in tools: action_group_name, function_name = _get_action_group_and_function_names(tool) tools_by_action_group[action_group_name].append(tool) + for action_group_name, functions in tools_by_action_group.items(): bedrock_client.create_agent_action_group( actionGroupName=action_group_name, @@ -213,6 +214,15 @@ def _create_bedrock_action_groups( agentVersion="DRAFT", ) + if enable_human_input: + bedrock_client.create_agent_action_group( + actionGroupName="UserInputAction", + parentActionGroupSignature="AMAZON.UserInput", + actionGroupState="ENABLED", + agentId=agent_id, + agentVersion="DRAFT", + ) + def _tool_to_function(tool: BaseTool) -> dict: """ @@ -395,6 +405,7 @@ def create_agent( bedrock_endpoint_url: Optional[str] = None, runtime_endpoint_url: Optional[str] = None, enable_trace: Optional[bool] = False, + enable_human_input: Optional[bool] = False, **kwargs: Any, ) -> BedrockAgentsRunnable: """ @@ -467,7 +478,7 @@ def create_agent( guardrail_configuration=guardrail_configuration, idle_session_ttl_in_seconds=idle_session_ttl_in_seconds, ) - _create_bedrock_action_groups(bedrock_client, agent_id, tools) + _create_bedrock_action_groups(bedrock_client, agent_id, tools, enable_human_input) _prepare_agent(bedrock_client, agent_id) except Exception as exception: logging.error(f"Error in create agent call: {exception}") From cb6f9c7a9f3f7d7f5a7bcbe876dcc3de0e1b9f0a Mon Sep 17 00:00:00 2001 From: John Baker Date: Thu, 21 Nov 2024 16:22:31 -0800 Subject: [PATCH 2/7] Add human feedback integration test. --- .../agents/test_bedrock_agents.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py b/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py index f65597cc..cc280d8b 100644 --- a/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py +++ b/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py @@ -4,6 +4,7 @@ import operator import time import uuid +from decimal import Decimal from typing import Any, Tuple, TypedDict, Union import boto3 @@ -435,3 +436,65 @@ def should_continue(data): _delete_agent_role(agent_resource_role_arn=agent_resource_role_arn) if agent_runnable: _delete_agent(agent_id=agent_runnable.agent_id) + + +def is_asking_location(response): + import re + # Common patterns for asking about location + patterns = [ + r'what (?:is )?(?:the )?location', + r'what (?:is )?(?:the )?location', + r'(?=.*location)(?=.*\?)', + r'(?=.*city)(?=.*\?)' + ] + + # Combine all patterns with OR operator and make case insensitive + combined_pattern = '|'.join(patterns) + + # Check if any pattern matches + return bool(re.search(combined_pattern, response.lower())) + + +def test_weather_agent_with_human_input(): + @tool + def get_weather(location: str) -> str: + """ + Get the weather of a location + + Args: + location: location of the place + """ + if location.lower() == "seattle": + return f"It is raining in {location}" + return f"It is hot and humid in {location}" + + foundation_model = "anthropic.claude-3-sonnet-20240229-v1:0" + tools = [get_weather] + agent_resource_role_arn = None + agent = None + try: + agent_resource_role_arn = _create_agent_role( + agent_region="us-west-2", foundation_model=foundation_model + ) + agent = BedrockAgentsRunnable.create_agent( + agent_name="weather_agent", + agent_resource_role_arn=agent_resource_role_arn, + foundation_model=foundation_model, + instruction=""" + You are an agent who helps with getting weather for a given location. + If the user does not provide a location then ask for the location and be + sure to use the word 'location'.""", + tools=tools, + enable_humidity=True, + ) + agent_executor = AgentExecutor(agent=agent, tools=tools) # type: ignore[arg-type] + output = agent_executor.invoke({"input": "what is the weather?"}) + + assert is_asking_location(output["output"]) + except Exception as ex: + raise ex + finally: + if agent_resource_role_arn: + _delete_agent_role(agent_resource_role_arn) + if agent: + _delete_agent(agent.agent_id) From 018de49b1c2a0d1ea58806a2b5a7c675d3c76802 Mon Sep 17 00:00:00 2001 From: John Baker Date: Fri, 22 Nov 2024 14:50:47 -0800 Subject: [PATCH 3/7] Add the skip marker to the new integration test. --- libs/aws/tests/integration_tests/agents/test_bedrock_agents.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py b/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py index cc280d8b..34c04ad6 100644 --- a/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py +++ b/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py @@ -454,7 +454,7 @@ def is_asking_location(response): # Check if any pattern matches return bool(re.search(combined_pattern, response.lower())) - +@pytest.mark.skip def test_weather_agent_with_human_input(): @tool def get_weather(location: str) -> str: From e7b601ea1a594008556141795f1e42f043c94101 Mon Sep 17 00:00:00 2001 From: John Baker Date: Fri, 22 Nov 2024 16:28:20 -0800 Subject: [PATCH 4/7] Fix lint issues. --- libs/aws/langchain_aws/agents/base.py | 9 +++++++-- .../integration_tests/agents/test_bedrock_agents.py | 13 +++++++------ 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/libs/aws/langchain_aws/agents/base.py b/libs/aws/langchain_aws/agents/base.py index e9044b8f..557581bd 100644 --- a/libs/aws/langchain_aws/agents/base.py +++ b/libs/aws/langchain_aws/agents/base.py @@ -193,7 +193,10 @@ def _get_action_group_and_function_names(tool: BaseTool) -> Tuple[str, str]: def _create_bedrock_action_groups( - bedrock_client: Any, agent_id: str, tools: List[BaseTool], enable_human_input: Optional[bool] = False + bedrock_client: Any, + agent_id: str, + tools: List[BaseTool], + enable_human_input: Optional[bool] = False, ) -> None: """Create the bedrock action groups for the agent""" @@ -478,7 +481,9 @@ def create_agent( guardrail_configuration=guardrail_configuration, idle_session_ttl_in_seconds=idle_session_ttl_in_seconds, ) - _create_bedrock_action_groups(bedrock_client, agent_id, tools, enable_human_input) + _create_bedrock_action_groups( + bedrock_client, agent_id, tools, enable_human_input + ) _prepare_agent(bedrock_client, agent_id) except Exception as exception: logging.error(f"Error in create agent call: {exception}") diff --git a/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py b/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py index 34c04ad6..74c83075 100644 --- a/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py +++ b/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py @@ -4,7 +4,6 @@ import operator import time import uuid -from decimal import Decimal from typing import Any, Tuple, TypedDict, Union import boto3 @@ -440,20 +439,22 @@ def should_continue(data): def is_asking_location(response): import re + # Common patterns for asking about location patterns = [ - r'what (?:is )?(?:the )?location', - r'what (?:is )?(?:the )?location', - r'(?=.*location)(?=.*\?)', - r'(?=.*city)(?=.*\?)' + r"what (?:is )?(?:the )?location", + r"what (?:is )?(?:the )?location", + r"(?=.*location)(?=.*\?)", + r"(?=.*city)(?=.*\?)", ] # Combine all patterns with OR operator and make case insensitive - combined_pattern = '|'.join(patterns) + combined_pattern = "|".join(patterns) # Check if any pattern matches return bool(re.search(combined_pattern, response.lower())) + @pytest.mark.skip def test_weather_agent_with_human_input(): @tool From d9c178a1ec8ad794b174f29ad5f7737f4a483c47 Mon Sep 17 00:00:00 2001 From: John Baker Date: Tue, 26 Nov 2024 16:08:06 -0800 Subject: [PATCH 5/7] Improve the human input test to check the agent configuration for the action group. --- libs/aws/langchain_aws/agents/base.py | 2 + .../agents/test_bedrock_agents.py | 58 +++++++++++++++++-- 2 files changed, 55 insertions(+), 5 deletions(-) diff --git a/libs/aws/langchain_aws/agents/base.py b/libs/aws/langchain_aws/agents/base.py index 557581bd..8b38a123 100644 --- a/libs/aws/langchain_aws/agents/base.py +++ b/libs/aws/langchain_aws/agents/base.py @@ -448,6 +448,8 @@ def create_agent( runtime_endpoint_url: Endpoint URL for bedrock agent runtime enable_trace: Boolean flag to specify whether trace should be enabled when invoking the agent + enable_human_input: Boolean flag to specify whether a human as a tool should + be enabled for the agent. **kwargs: Additional arguments Returns: BedrockAgentsRunnable configured to invoke the Bedrock agent diff --git a/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py b/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py index 74c83075..069406cd 100644 --- a/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py +++ b/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py @@ -455,6 +455,43 @@ def is_asking_location(response): return bool(re.search(combined_pattern, response.lower())) +def get_latest_agent_version(agent_id: str) -> str: + """ + Gets the latest version of a Bedrock Agent by creation date. + + Args: + agent_id (str): The ID of the Bedrock Agent + + Returns: + dict: The latest agent version information + + Raises: + Exception: If no agent versions are found or if API call fails + """ + # Initialize Bedrock Agents client + client = boto3.client("bedrock-agent") + + try: + # Get all versions of the agent + response = client.list_agent_versions(agentId=agent_id, maxResults=100) + + if not response.get("agentVersionSummaries"): + raise Exception(f"No versions found for agent {agent_id}") + + # Sort versions by creation date + versions = sorted( + response["agentVersionSummaries"], + key=lambda x: x["updatedAt"], + reverse=True, + ) + + # Return the most recent version + return str(versions[0]["agentVersion"]) + + except Exception as e: + raise Exception(f"Error getting agent versions: {str(e)}") + + @pytest.mark.skip def test_weather_agent_with_human_input(): @tool @@ -484,14 +521,25 @@ def get_weather(location: str) -> str: instruction=""" You are an agent who helps with getting weather for a given location. If the user does not provide a location then ask for the location and be - sure to use the word 'location'.""", + sure to use the word 'location'. """, tools=tools, - enable_humidity=True, + enable_human_input=True, ) - agent_executor = AgentExecutor(agent=agent, tools=tools) # type: ignore[arg-type] - output = agent_executor.invoke({"input": "what is the weather?"}) - assert is_asking_location(output["output"]) + # check human input is in the action groups + bedrock_client = boto3.client("bedrock-agent") + version = get_latest_agent_version(agent.agent_id) + paginator = bedrock_client.get_paginator("list_agent_action_groups") + for page in paginator.paginate( + agentId=agent.agent_id, + agentVersion=version, + PaginationConfig={"PageSize": 10}, + ): + for summary in page["actionGroupSummaries"]: + if str(summary["actionGroupName"]).lower() == "userinputactions": + return True + + return False except Exception as ex: raise ex finally: From 09c3a70a131fbdee766a0061701b49f9fb70937c Mon Sep 17 00:00:00 2001 From: John Baker Date: Tue, 26 Nov 2024 16:15:17 -0800 Subject: [PATCH 6/7] Remove unused function for integration test. --- .../agents/test_bedrock_agents.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py b/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py index 069406cd..12b7f87e 100644 --- a/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py +++ b/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py @@ -437,24 +437,6 @@ def should_continue(data): _delete_agent(agent_id=agent_runnable.agent_id) -def is_asking_location(response): - import re - - # Common patterns for asking about location - patterns = [ - r"what (?:is )?(?:the )?location", - r"what (?:is )?(?:the )?location", - r"(?=.*location)(?=.*\?)", - r"(?=.*city)(?=.*\?)", - ] - - # Combine all patterns with OR operator and make case insensitive - combined_pattern = "|".join(patterns) - - # Check if any pattern matches - return bool(re.search(combined_pattern, response.lower())) - - def get_latest_agent_version(agent_id: str) -> str: """ Gets the latest version of a Bedrock Agent by creation date. From 5629ab47e7cac97baad5f12b9a5407059ce47b66 Mon Sep 17 00:00:00 2001 From: John Baker Date: Wed, 27 Nov 2024 11:03:38 -0800 Subject: [PATCH 7/7] Add check for UserInputAction as ENABLED. Change to use assert as check for test. --- .../integration_tests/agents/test_bedrock_agents.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py b/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py index 12b7f87e..5f59f6d1 100644 --- a/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py +++ b/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py @@ -512,16 +512,21 @@ def get_weather(location: str) -> str: bedrock_client = boto3.client("bedrock-agent") version = get_latest_agent_version(agent.agent_id) paginator = bedrock_client.get_paginator("list_agent_action_groups") + has_human_input_tool = False for page in paginator.paginate( agentId=agent.agent_id, agentVersion=version, PaginationConfig={"PageSize": 10}, ): for summary in page["actionGroupSummaries"]: - if str(summary["actionGroupName"]).lower() == "userinputactions": - return True - - return False + if ( + str(summary["actionGroupName"]).lower() == "userinputaction" + and str(summary["actionGroupState"]).lower() == "enabled" + ): + has_human_input_tool = True + break + + assert has_human_input_tool except Exception as ex: raise ex finally: