Skip to content

Commit

Permalink
Added a flag to enable human input for Bedrock Agents (#289)
Browse files Browse the repository at this point in the history
- Simplify the create of an agent which uses human input by including a
flag (False by default to match current behavior).
- Add new integration test for the use case.
  • Loading branch information
jdbaker01 authored Dec 4, 2024

Verified

This commit was signed with the committer’s verified signature.
yijiasu-crypto Yijia Su
1 parent 28f3718 commit f7dc810
Showing 2 changed files with 119 additions and 2 deletions.
22 changes: 20 additions & 2 deletions libs/aws/langchain_aws/agents/base.py
Original file line number Diff line number Diff line change
@@ -193,14 +193,18 @@ def _get_action_group_and_function_names(tool: BaseTool) -> Tuple[str, str]:


def _create_bedrock_action_groups(
bedrock_client: Any, agent_id: str, tools: List[BaseTool]
bedrock_client: Any,
agent_id: str,
tools: List[BaseTool],
enable_human_input: Optional[bool] = False,
) -> None:
"""Create the bedrock action groups for the agent"""

tools_by_action_group = defaultdict(list)
for tool in tools:
action_group_name, function_name = _get_action_group_and_function_names(tool)
tools_by_action_group[action_group_name].append(tool)

for action_group_name, functions in tools_by_action_group.items():
bedrock_client.create_agent_action_group(
actionGroupName=action_group_name,
@@ -213,6 +217,15 @@ def _create_bedrock_action_groups(
agentVersion="DRAFT",
)

if enable_human_input:
bedrock_client.create_agent_action_group(
actionGroupName="UserInputAction",
parentActionGroupSignature="AMAZON.UserInput",
actionGroupState="ENABLED",
agentId=agent_id,
agentVersion="DRAFT",
)


def _tool_to_function(tool: BaseTool) -> dict:
"""
@@ -395,6 +408,7 @@ def create_agent(
bedrock_endpoint_url: Optional[str] = None,
runtime_endpoint_url: Optional[str] = None,
enable_trace: Optional[bool] = False,
enable_human_input: Optional[bool] = False,
**kwargs: Any,
) -> BedrockAgentsRunnable:
"""
@@ -434,6 +448,8 @@ def create_agent(
runtime_endpoint_url: Endpoint URL for bedrock agent runtime
enable_trace: Boolean flag to specify whether trace should be enabled when
invoking the agent
enable_human_input: Boolean flag to specify whether a human as a tool should
be enabled for the agent.
**kwargs: Additional arguments
Returns:
BedrockAgentsRunnable configured to invoke the Bedrock agent
@@ -467,7 +483,9 @@ def create_agent(
guardrail_configuration=guardrail_configuration,
idle_session_ttl_in_seconds=idle_session_ttl_in_seconds,
)
_create_bedrock_action_groups(bedrock_client, agent_id, tools)
_create_bedrock_action_groups(
bedrock_client, agent_id, tools, enable_human_input
)
_prepare_agent(bedrock_client, agent_id)
except Exception as exception:
logging.error(f"Error in create agent call: {exception}")
99 changes: 99 additions & 0 deletions libs/aws/tests/integration_tests/agents/test_bedrock_agents.py
Original file line number Diff line number Diff line change
@@ -435,3 +435,102 @@ def should_continue(data):
_delete_agent_role(agent_resource_role_arn=agent_resource_role_arn)
if agent_runnable:
_delete_agent(agent_id=agent_runnable.agent_id)


def get_latest_agent_version(agent_id: str) -> str:
"""
Gets the latest version of a Bedrock Agent by creation date.
Args:
agent_id (str): The ID of the Bedrock Agent
Returns:
dict: The latest agent version information
Raises:
Exception: If no agent versions are found or if API call fails
"""
# Initialize Bedrock Agents client
client = boto3.client("bedrock-agent")

try:
# Get all versions of the agent
response = client.list_agent_versions(agentId=agent_id, maxResults=100)

if not response.get("agentVersionSummaries"):
raise Exception(f"No versions found for agent {agent_id}")

# Sort versions by creation date
versions = sorted(
response["agentVersionSummaries"],
key=lambda x: x["updatedAt"],
reverse=True,
)

# Return the most recent version
return str(versions[0]["agentVersion"])

except Exception as e:
raise Exception(f"Error getting agent versions: {str(e)}")


@pytest.mark.skip
def test_weather_agent_with_human_input():
@tool
def get_weather(location: str) -> str:
"""
Get the weather of a location
Args:
location: location of the place
"""
if location.lower() == "seattle":
return f"It is raining in {location}"
return f"It is hot and humid in {location}"

foundation_model = "anthropic.claude-3-sonnet-20240229-v1:0"
tools = [get_weather]
agent_resource_role_arn = None
agent = None
try:
agent_resource_role_arn = _create_agent_role(
agent_region="us-west-2", foundation_model=foundation_model
)
agent = BedrockAgentsRunnable.create_agent(
agent_name="weather_agent",
agent_resource_role_arn=agent_resource_role_arn,
foundation_model=foundation_model,
instruction="""
You are an agent who helps with getting weather for a given location.
If the user does not provide a location then ask for the location and be
sure to use the word 'location'. """,
tools=tools,
enable_human_input=True,
)

# check human input is in the action groups
bedrock_client = boto3.client("bedrock-agent")
version = get_latest_agent_version(agent.agent_id)
paginator = bedrock_client.get_paginator("list_agent_action_groups")
has_human_input_tool = False
for page in paginator.paginate(
agentId=agent.agent_id,
agentVersion=version,
PaginationConfig={"PageSize": 10},
):
for summary in page["actionGroupSummaries"]:
if (
str(summary["actionGroupName"]).lower() == "userinputaction"
and str(summary["actionGroupState"]).lower() == "enabled"
):
has_human_input_tool = True
break

assert has_human_input_tool
except Exception as ex:
raise ex
finally:
if agent_resource_role_arn:
_delete_agent_role(agent_resource_role_arn)
if agent:
_delete_agent(agent.agent_id)

0 comments on commit f7dc810

Please sign in to comment.