Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a flag to enable human input for Bedrock Agents #289

Merged
merged 8 commits into from
Dec 4, 2024
22 changes: 20 additions & 2 deletions libs/aws/langchain_aws/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,18 @@ def _get_action_group_and_function_names(tool: BaseTool) -> Tuple[str, str]:


def _create_bedrock_action_groups(
bedrock_client: Any, agent_id: str, tools: List[BaseTool]
bedrock_client: Any,
agent_id: str,
tools: List[BaseTool],
enable_human_input: Optional[bool] = False,
) -> None:
"""Create the bedrock action groups for the agent"""

tools_by_action_group = defaultdict(list)
for tool in tools:
action_group_name, function_name = _get_action_group_and_function_names(tool)
tools_by_action_group[action_group_name].append(tool)

for action_group_name, functions in tools_by_action_group.items():
bedrock_client.create_agent_action_group(
actionGroupName=action_group_name,
Expand All @@ -213,6 +217,15 @@ def _create_bedrock_action_groups(
agentVersion="DRAFT",
)

if enable_human_input:
bedrock_client.create_agent_action_group(
actionGroupName="UserInputAction",
parentActionGroupSignature="AMAZON.UserInput",
actionGroupState="ENABLED",
agentId=agent_id,
agentVersion="DRAFT",
)


def _tool_to_function(tool: BaseTool) -> dict:
"""
Expand Down Expand Up @@ -395,6 +408,7 @@ def create_agent(
bedrock_endpoint_url: Optional[str] = None,
runtime_endpoint_url: Optional[str] = None,
enable_trace: Optional[bool] = False,
enable_human_input: Optional[bool] = False,
**kwargs: Any,
) -> BedrockAgentsRunnable:
"""
Expand Down Expand Up @@ -434,6 +448,8 @@ def create_agent(
runtime_endpoint_url: Endpoint URL for bedrock agent runtime
enable_trace: Boolean flag to specify whether trace should be enabled when
invoking the agent
enable_human_input: Boolean flag to specify whether a human as a tool should
be enabled for the agent.
**kwargs: Additional arguments
Returns:
BedrockAgentsRunnable configured to invoke the Bedrock agent
Expand Down Expand Up @@ -467,7 +483,9 @@ def create_agent(
guardrail_configuration=guardrail_configuration,
idle_session_ttl_in_seconds=idle_session_ttl_in_seconds,
)
_create_bedrock_action_groups(bedrock_client, agent_id, tools)
_create_bedrock_action_groups(
bedrock_client, agent_id, tools, enable_human_input
)
_prepare_agent(bedrock_client, agent_id)
except Exception as exception:
logging.error(f"Error in create agent call: {exception}")
Expand Down
99 changes: 99 additions & 0 deletions libs/aws/tests/integration_tests/agents/test_bedrock_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,3 +435,102 @@ def should_continue(data):
_delete_agent_role(agent_resource_role_arn=agent_resource_role_arn)
if agent_runnable:
_delete_agent(agent_id=agent_runnable.agent_id)


def get_latest_agent_version(agent_id: str) -> str:
"""
Gets the latest version of a Bedrock Agent by creation date.

Args:
agent_id (str): The ID of the Bedrock Agent

Returns:
dict: The latest agent version information

Raises:
Exception: If no agent versions are found or if API call fails
"""
# Initialize Bedrock Agents client
client = boto3.client("bedrock-agent")

try:
# Get all versions of the agent
response = client.list_agent_versions(agentId=agent_id, maxResults=100)

if not response.get("agentVersionSummaries"):
raise Exception(f"No versions found for agent {agent_id}")

# Sort versions by creation date
versions = sorted(
response["agentVersionSummaries"],
key=lambda x: x["updatedAt"],
reverse=True,
)

# Return the most recent version
return str(versions[0]["agentVersion"])

except Exception as e:
raise Exception(f"Error getting agent versions: {str(e)}")


@pytest.mark.skip
def test_weather_agent_with_human_input():
@tool
def get_weather(location: str) -> str:
"""
Get the weather of a location

Args:
location: location of the place
"""
if location.lower() == "seattle":
return f"It is raining in {location}"
return f"It is hot and humid in {location}"

foundation_model = "anthropic.claude-3-sonnet-20240229-v1:0"
tools = [get_weather]
agent_resource_role_arn = None
agent = None
try:
agent_resource_role_arn = _create_agent_role(
agent_region="us-west-2", foundation_model=foundation_model
)
agent = BedrockAgentsRunnable.create_agent(
agent_name="weather_agent",
agent_resource_role_arn=agent_resource_role_arn,
foundation_model=foundation_model,
instruction="""
You are an agent who helps with getting weather for a given location.
If the user does not provide a location then ask for the location and be
sure to use the word 'location'. """,
tools=tools,
enable_human_input=True,
)

# check human input is in the action groups
bedrock_client = boto3.client("bedrock-agent")
version = get_latest_agent_version(agent.agent_id)
paginator = bedrock_client.get_paginator("list_agent_action_groups")
has_human_input_tool = False
for page in paginator.paginate(
agentId=agent.agent_id,
agentVersion=version,
PaginationConfig={"PageSize": 10},
):
for summary in page["actionGroupSummaries"]:
if (
str(summary["actionGroupName"]).lower() == "userinputaction"
and str(summary["actionGroupState"]).lower() == "enabled"
):
has_human_input_tool = True
break

assert has_human_input_tool
except Exception as ex:
raise ex
finally:
if agent_resource_role_arn:
_delete_agent_role(agent_resource_role_arn)
if agent:
_delete_agent(agent.agent_id)