diff --git a/libs/aws/langchain_aws/agents/base.py b/libs/aws/langchain_aws/agents/base.py index eddd49f7..8b38a123 100644 --- a/libs/aws/langchain_aws/agents/base.py +++ b/libs/aws/langchain_aws/agents/base.py @@ -193,7 +193,10 @@ def _get_action_group_and_function_names(tool: BaseTool) -> Tuple[str, str]: def _create_bedrock_action_groups( - bedrock_client: Any, agent_id: str, tools: List[BaseTool] + bedrock_client: Any, + agent_id: str, + tools: List[BaseTool], + enable_human_input: Optional[bool] = False, ) -> None: """Create the bedrock action groups for the agent""" @@ -201,6 +204,7 @@ def _create_bedrock_action_groups( for tool in tools: action_group_name, function_name = _get_action_group_and_function_names(tool) tools_by_action_group[action_group_name].append(tool) + for action_group_name, functions in tools_by_action_group.items(): bedrock_client.create_agent_action_group( actionGroupName=action_group_name, @@ -213,6 +217,15 @@ def _create_bedrock_action_groups( agentVersion="DRAFT", ) + if enable_human_input: + bedrock_client.create_agent_action_group( + actionGroupName="UserInputAction", + parentActionGroupSignature="AMAZON.UserInput", + actionGroupState="ENABLED", + agentId=agent_id, + agentVersion="DRAFT", + ) + def _tool_to_function(tool: BaseTool) -> dict: """ @@ -395,6 +408,7 @@ def create_agent( bedrock_endpoint_url: Optional[str] = None, runtime_endpoint_url: Optional[str] = None, enable_trace: Optional[bool] = False, + enable_human_input: Optional[bool] = False, **kwargs: Any, ) -> BedrockAgentsRunnable: """ @@ -434,6 +448,8 @@ def create_agent( runtime_endpoint_url: Endpoint URL for bedrock agent runtime enable_trace: Boolean flag to specify whether trace should be enabled when invoking the agent + enable_human_input: Boolean flag to specify whether a human as a tool should + be enabled for the agent. **kwargs: Additional arguments Returns: BedrockAgentsRunnable configured to invoke the Bedrock agent @@ -467,7 +483,9 @@ def create_agent( guardrail_configuration=guardrail_configuration, idle_session_ttl_in_seconds=idle_session_ttl_in_seconds, ) - _create_bedrock_action_groups(bedrock_client, agent_id, tools) + _create_bedrock_action_groups( + bedrock_client, agent_id, tools, enable_human_input + ) _prepare_agent(bedrock_client, agent_id) except Exception as exception: logging.error(f"Error in create agent call: {exception}") diff --git a/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py b/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py index f65597cc..5f59f6d1 100644 --- a/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py +++ b/libs/aws/tests/integration_tests/agents/test_bedrock_agents.py @@ -435,3 +435,102 @@ def should_continue(data): _delete_agent_role(agent_resource_role_arn=agent_resource_role_arn) if agent_runnable: _delete_agent(agent_id=agent_runnable.agent_id) + + +def get_latest_agent_version(agent_id: str) -> str: + """ + Gets the latest version of a Bedrock Agent by creation date. + + Args: + agent_id (str): The ID of the Bedrock Agent + + Returns: + dict: The latest agent version information + + Raises: + Exception: If no agent versions are found or if API call fails + """ + # Initialize Bedrock Agents client + client = boto3.client("bedrock-agent") + + try: + # Get all versions of the agent + response = client.list_agent_versions(agentId=agent_id, maxResults=100) + + if not response.get("agentVersionSummaries"): + raise Exception(f"No versions found for agent {agent_id}") + + # Sort versions by creation date + versions = sorted( + response["agentVersionSummaries"], + key=lambda x: x["updatedAt"], + reverse=True, + ) + + # Return the most recent version + return str(versions[0]["agentVersion"]) + + except Exception as e: + raise Exception(f"Error getting agent versions: {str(e)}") + + +@pytest.mark.skip +def test_weather_agent_with_human_input(): + @tool + def get_weather(location: str) -> str: + """ + Get the weather of a location + + Args: + location: location of the place + """ + if location.lower() == "seattle": + return f"It is raining in {location}" + return f"It is hot and humid in {location}" + + foundation_model = "anthropic.claude-3-sonnet-20240229-v1:0" + tools = [get_weather] + agent_resource_role_arn = None + agent = None + try: + agent_resource_role_arn = _create_agent_role( + agent_region="us-west-2", foundation_model=foundation_model + ) + agent = BedrockAgentsRunnable.create_agent( + agent_name="weather_agent", + agent_resource_role_arn=agent_resource_role_arn, + foundation_model=foundation_model, + instruction=""" + You are an agent who helps with getting weather for a given location. + If the user does not provide a location then ask for the location and be + sure to use the word 'location'. """, + tools=tools, + enable_human_input=True, + ) + + # check human input is in the action groups + bedrock_client = boto3.client("bedrock-agent") + version = get_latest_agent_version(agent.agent_id) + paginator = bedrock_client.get_paginator("list_agent_action_groups") + has_human_input_tool = False + for page in paginator.paginate( + agentId=agent.agent_id, + agentVersion=version, + PaginationConfig={"PageSize": 10}, + ): + for summary in page["actionGroupSummaries"]: + if ( + str(summary["actionGroupName"]).lower() == "userinputaction" + and str(summary["actionGroupState"]).lower() == "enabled" + ): + has_human_input_tool = True + break + + assert has_human_input_tool + except Exception as ex: + raise ex + finally: + if agent_resource_role_arn: + _delete_agent_role(agent_resource_role_arn) + if agent: + _delete_agent(agent.agent_id)