diff --git a/.github/workflows/tester.yml b/.github/workflows/tester.yml index 06bf8af7..9b1392c1 100644 --- a/.github/workflows/tester.yml +++ b/.github/workflows/tester.yml @@ -21,7 +21,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install pytest pandas tqdm + pip install pytest pandas tqdm IPython pip install -e . pip install -r requirements.txt - name: Running Tests diff --git a/requirements.txt b/requirements.txt index 35307189..685940f2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,10 +11,6 @@ gspread==5.10.0 gspread_dataframe numpy requests -pylint==2.8.3 -pytest==6.0.2 -pytest-cov==2.11.1 -pytest-xdist==2.1.0 pyyaml rouge-score torch diff --git a/src/dfcx_scrapi/core/environments.py b/src/dfcx_scrapi/core/environments.py index 5aa361f3..847625ad 100644 --- a/src/dfcx_scrapi/core/environments.py +++ b/src/dfcx_scrapi/core/environments.py @@ -46,11 +46,17 @@ def __init__( creds=creds, ) - if agent_id: - self.agent_id = agent_id - + self.agent_id = agent_id self._versions = versions.Versions(creds=self.creds) self._flows = flows.Flows(creds=self.creds) + self._environments_map = None + + @property + def environments_map(self) -> Dict[str, str]: + """Property for the environments map.""" + if self._environments_map is None: + self._environments_map = self.get_environments_map(self.agent_id) + return self._environments_map @staticmethod def _get_flow_version_id( @@ -125,6 +131,8 @@ def get_environments_map( for environment in self.list_environments(agent_id) } + self._environments_map = environments_dict + return environments_dict @scrapi_base.api_call_counter_decorator diff --git a/src/dfcx_scrapi/core/scrapi_base.py b/src/dfcx_scrapi/core/scrapi_base.py index 343b36cb..bebdb5c0 100644 --- a/src/dfcx_scrapi/core/scrapi_base.py +++ b/src/dfcx_scrapi/core/scrapi_base.py @@ -135,6 +135,9 @@ def _set_region(resource_id: str): instantiating other library client objects, or None if the location is "global" """ + if not resource_id: + raise ValueError("resource_id must not be None.") + try: location = resource_id.split("/")[3] except IndexError as err: @@ -334,8 +337,8 @@ def _parse_resource_path( "format": "`projects//locations//securitySettings/`", # noqa: E501 }, "session": { - "matcher": fr"{matcher_root}/agents/(?P{standard_id_match})/sessions/(?P{session_id_match})$", # noqa: E501 - "format": "`projects//locations//agents//sessions/`", # noqa: E501 + "matcher": fr"{matcher_root}/agents/(?P{standard_id_match})(?:/environments/(?P{standard_id_match}))?/sessions/(?P{session_id_match})$", # noqa: E501 + "format": "`projects//locations//agents//sessions/` or `projects//locations//agents//environments//sessions/`", # noqa: E501 }, "session_entity_type": { "matcher": fr"{matcher_root}/agents/(?P{standard_id_match})/sessions/(?P{session_id_match})/entityTypes/(?P{entity_id_match})$", # noqa: E501 diff --git a/src/dfcx_scrapi/core/sessions.py b/src/dfcx_scrapi/core/sessions.py index 8f2d1ff2..44ea5030 100644 --- a/src/dfcx_scrapi/core/sessions.py +++ b/src/dfcx_scrapi/core/sessions.py @@ -16,14 +16,18 @@ import logging import uuid -from typing import Dict, List +from typing import Any, Dict, List from google.cloud.dialogflowcx_v3beta1 import services, types from google.protobuf.json_format import MessageToDict from IPython.display import Markdown, display from proto.marshal.collections import maps -from dfcx_scrapi.core import flows, playbooks, scrapi_base, tools +from dfcx_scrapi.core.environments import Environments +from dfcx_scrapi.core.flows import Flows +from dfcx_scrapi.core.playbooks import Playbooks +from dfcx_scrapi.core.scrapi_base import ScrapiBase +from dfcx_scrapi.core.tools import Tools # logging config logging.basicConfig( @@ -33,7 +37,7 @@ ) -class Sessions(scrapi_base.ScrapiBase): +class Sessions(ScrapiBase): """Core Class for CX Session Resource functions.""" def __init__( @@ -53,22 +57,59 @@ def __init__( creds=creds, scope=scope ) - self.session_id = session_id + self._session_id = session_id self.agent_id = agent_id self.tools_map = tools_map self.playbooks_map = playbooks_map self.flows_map = flows_map + self._env_client = None + self._tools_client = None + self._playbooks_client = None + self._flows_client = None @property def session_id(self): + """Property for the session ID, parses the resource path if needed.""" + if self._session_id: + self._parse_resource_path("session", self._session_id) + return self._session_id - @session_id.setter - def session_id(self, value): - if value: - self._parse_resource_path("session", value) + @property + def playbooks_client(self): + """Property for Playbooks client.""" + if self._playbooks_client is None: + if not self.agent_id: + raise ValueError( + "agent_id must be set to use Playbooks Client.") + self._playbooks_client = Playbooks( + agent_id=self.agent_id, creds=self.creds + ) + return self._playbooks_client + + @property + def tools_client(self): + """Property for Tools client.""" + if self._tools_client is None: + self._tools_client = Tools(creds=self.creds) + + return self._tools_client + + @property + def flows_client(self): + """Property for Flows client.""" + if self._flows_client is None: + self._flows_client = Flows(creds=self.creds) - self._session_id = value + return self._flows_client + + @property + def env_client(self): + """Property for Environments client.""" + if self._env_client is None: + self._env_client = Environments(creds=self.creds) + + return self._env_client @staticmethod def printmd(string): @@ -121,25 +162,22 @@ def get_tool_params(self, params: maps.MapComposite): def get_playbook_name(self, playbook_id: str): agent_id = self.parse_agent_id(playbook_id) if not self.playbooks_map: - playbook_client = playbooks.Playbooks( - agent_id=agent_id, creds=self.creds - ) - self.playbooks_map = playbook_client.get_playbooks_map(agent_id) + self.playbooks_map = self.playbooks_client.get_playbooks_map( + agent_id) return self.playbooks_map[playbook_id] def get_tool_name(self, tool_use: types.example.ToolUse) -> str: agent_id = self.parse_agent_id(tool_use.tool) if not self.tools_map: - tool_client = tools.Tools(creds=self.creds) - self.tools_map = tool_client.get_tools_map(agent_id) + self.tools_map = self.tools_client.get_tools_map(agent_id) + return self.tools_map[tool_use.tool] def get_flow_name(self, flow_id: str): agent_id = self.parse_agent_id(flow_id) if not self.flows_map: - flow_client = flows.Flows(creds=self.creds) - self.flows_map = flow_client.get_flows_map(agent_id) + self.flows_map = self.flows_client.get_flows_map(agent_id) return self.flows_map[flow_id] @@ -208,125 +246,52 @@ def collect_flow_responses( return flow_responses def build_session_id( - self, agent_id: str = None, overwrite: bool = True + self, agent_id: str = None, overwrite: bool = True, + environment_name: str = None ) -> str: """Creates a valid UUID-4 Session ID to use with other methods. Args: + agent_id: the Agent ID of the CX Agent. overwrite (Optional), if a session_id already exists, this will overwrite the existing Session ID parameter. Defaults to True. + environment_name: (Optional) the human readable Environment name to + use when building the session ID. If this is not provided, DRAFT is + assumed. """ - agent_parts = self._parse_resource_path("agent", agent_id) - session_id = ( - f"projects/{agent_parts['project']}/" - f"locations/{agent_parts['location']}/agents/" - f"{agent_parts['agent']}/sessions/{uuid.uuid4()}" - ) - - if overwrite: - self.session_id = session_id - - return session_id - - def run_conversation( - self, - agent_id: str = None, - session_id: str = None, - conversation: List[str] = None, - parameters=None, - response_text=False, - ): - """Tests a full conversation with the specified CX Agent. - - Args: - agent_id: the Agent ID of the CX Agent to have the conversation with. - session_id: an RFC 4122 formatted UUID to be used as the unique ID - for the duration of the conversation session. When using Python - uuid library, uuid.uuid4() is preferred. - conversation: a List of Strings that represent the USER utterances - for the given conversation, in the order they would happen - chronologically in the conversation. - Ex: - ['I want to check my bill', 'yes', 'no that is all', 'thanks!'] - parameters: (Optional) Dict of CX Session Parameters to set in the - conversation. Typically this is set before a conversation starts. - response_text: Will provide the Agent Response text if set to True. - Default value is False. - - Returns: - None, the conversation Request/Response is printed to console. - """ - if not session_id: - session_id = self.session_id - - client_options = self._set_region(agent_id) - session_client = services.sessions.SessionsClient( - client_options=client_options, credentials=self.creds - ) - session_path = f"{agent_id}/sessions/{session_id}" - - if parameters: - query_params = types.session.QueryParameters(parameters=parameters) - text_input = types.session.TextInput(text="") - query_input = types.session.QueryInput( - text=text_input, language_code="en" - ) - request = types.session.DetectIntentRequest( - session=session_path, - query_params=query_params, - query_input=query_input, - ) + # Parse and validate the incoming agent_id + _ = self._parse_resource_path("agent", agent_id) - response = session_client.detect_intent(request=request) - for text in conversation: - text_input = types.session.TextInput(text=text) - query_input = types.session.QueryInput( - text=text_input, language_code="en" - ) - request = types.session.DetectIntentRequest( - session=session_path, query_input=query_input + if environment_name: + env = self._env_client.get_environment_by_display_name( + environment_name, agent_id ) - response = session_client.detect_intent(request=request) - query_result = response.query_result - - print("=" * 20) - print(f"Query text: {query_result.text}") - if "intent" in query_result: - print(f"Triggered Intent: {query_result.intent.display_name}") - - if "intent_detection_confidence" in query_result: - print( - f"Intent Confidence: \ - f{query_result.intent_detection_confidence}" - ) + if not env: + raise ValueError(f"Environment `{environment_name}` does not" + " exist.") + session_id = f"{env.name}/sessions/{uuid.uuid4()}" - print(f"Response Page: {query_result.current_page.display_name}") + else: + session_id = f"{agent_id}/sessions/{uuid.uuid4()}" - for param in query_result.parameters: - if param == "statusMessage": - print(f"Status Message: {query_result.parameters[param]}") + if overwrite: + self._session_id = session_id - if response_text: - concat_messages = " ".join( - [ - " ".join(response_message.text.text) - for response_message in query_result.response_messages - ] - ) - print(f"Response Text: {concat_messages}\n") + return session_id def detect_intent( self, agent_id, session_id, - text, - language_code="en", - parameters=None, - end_user_metadata=None, - populate_data_store_connection_signals=False, - intent_id: str = None + text: str = None, + language_code: str = "en", + parameters: Dict[str, Any] = None, + end_user_metadata: Dict[str, Any] = None, + populate_data_store_connection_signals: bool = False, + intent_id: str = None, + timezone: str = None ): """Returns the result of detect intent with texts as inputs. @@ -338,7 +303,9 @@ def detect_intent( session_id: an RFC 4122 formatted UUID to be used as the unique ID for the duration of the conversation session. When using Python uuid library, uuid.uuid4() is preferred. - text: the user utterance to run intent detection on + text: (Optional) the user utterance to run intent detection on + language_code: (Optional) corresponds to the language code to use with + query inputs to the agent. parameters: (Optional) Dict of CX Session Parameters to set in the conversation. Typically this is set before a conversation starts. end_user_metadata: (Optional) Dict of CX Session endUserMetadata to @@ -347,6 +314,14 @@ def detect_intent( stores are involved in serving the request then query result will be populated with data_store_connection_signals field which contains data that can help evaluations. + intent_id: fully qualified Intent ID path to pass in for query + input instead of text. This allows for the direct triggering of a + specific Intent, and will bypass the NLU engine. + timezone: (Optional) IANA Timezone database code to pass in with query + input which can be used by the agent runtime. For example, when + capturing datetime via system functions, they can be modified to + user the provied timezone vs. the default agent timezone. + Refs: https://www.iana.org/time-zones Returns: The CX query result from intent detection @@ -388,52 +363,18 @@ def detect_intent( "populate_data_store_connection_signals" ] = populate_data_store_connection_signals + if timezone: + query_param_mapping["time_zone"] = timezone + if query_param_mapping: query_params = types.session.QueryParameters(query_param_mapping) request.query_params = query_params - response = session_client.detect_intent(request) + response = session_client.detect_intent(request=request) query_result = response.query_result return query_result - def preset_parameters( - self, agent_id: str = None, session_id: str = None, parameters=None - ): - """Used to set session parameters before a conversation starts. - - Args: - agent_id: the Agent ID of the CX Agent to have the conversation with. - session_id: an RFC 4122 formatted UUID to be used as the unique ID - for the duration of the conversation session. When using Python - uuid library, uuid.uuid4() is preferred. - parameters: Dict of CX Session Parameters to set in the - conversation. Typically this is set before a conversation starts. - - Returns: - The CX query result from intent detection run with no text input - """ - client_options = self._set_region(agent_id) - session_client = services.sessions.SessionsClient( - client_options=client_options, credentials=self.creds - ) - session_path = f"{agent_id}/sessions/{session_id}" - - query_params = types.session.QueryParameters(parameters=parameters) - text_input = types.session.TextInput(text=None) - query_input = types.session.QueryInput( - text=text_input, language_code="en" - ) - request = types.session.DetectIntentRequest( - session=session_path, - query_params=query_params, - query_input=query_input, - ) - - response = session_client.detect_intent(request=request) - - return response - def get_agent_answer(self, user_query: str) -> str: """Extract the answer/citation from a Vertex Conversation response.""" diff --git a/tests/dfcx_scrapi/core/test_sessions.py b/tests/dfcx_scrapi/core/test_sessions.py new file mode 100644 index 00000000..e3fb6e30 --- /dev/null +++ b/tests/dfcx_scrapi/core/test_sessions.py @@ -0,0 +1,779 @@ +"""Unit Tests for Sessions.""" +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock, patch +import uuid + +import pytest +from google.protobuf import struct_pb2 +from google.cloud.dialogflowcx_v3beta1 import types + +from dfcx_scrapi.core.sessions import Sessions + +@pytest.fixture +def test_config(): + project_id = "my-project-id-1234" + location_id = "global" + agent_id = f"projects/{project_id}/locations/{location_id}/agents/fcdecc6a-3f2e-4f8d-abca-63426024d8bb" + intent_id = f"{agent_id}/intents/b1983e56-5c96-4b20-b15a-fb5c12b77500" + flow_id = f"{agent_id}/flows/925c0042-686f-4422-bf34-2a2a37b1a3ee" + tool_id = f"{agent_id}/tools/8e1205e8-bb5e-46ef-896c-1e4f72bb871c" + page_id = f"{flow_id}/pages/8e1205e8-bb5e-46ef-896c-1e4f72bb871c" + playbook_id = f"{agent_id}/playbooks/e79502db-78e0-4f54-8447-716107bb553e" + playbook_name = "mock_playbook" + environment_id = f"{agent_id}/environments/753de31b-a14b-45f3-a731-b40831ecfbc4" + environment_name = "My Test Environment" + session_id_plain = f"{agent_id}/sessions/a1b2c3d4-e5f6-7890-1234-567890abcdef" + session_id_with_env = f"{environment_id}/sessions/a1b2c3d4-e5f6-7890-1234-567890abcdef" + language_code = "en" + text = "Hello!" + tool_action = "my_tool_action" + tool_input_params = {"param1": "value1", "param2": "value2"} + tool_output_params = {"param3": "value3", "param4": "value4"} + tool_input_params_nested = {"top_key": {"param1": "value1", "param2": "value2"}} + tool_output_params_nested = {"top_key": {"param3": "value3", "param4": "value4"}} + + parameters = {"param1": "value1", "param2": "value2"} + end_user_metadata = {"user_id": "user123"} + + return { + "project_id": project_id, + "location_id": location_id, + "agent_id": agent_id, + "session_id_plain": session_id_plain, + "session_id_with_env": session_id_with_env, + "intent_id": intent_id, + "flow_id": flow_id, + "page_id": page_id, + "tool_id": tool_id, + "playbook_id": playbook_id, + "playbook_name": playbook_name, + "environment_id": environment_id, + "environment_name": environment_name, + "language_code": language_code, + "text": text, + "tool_action": tool_action, + "tool_input_params": tool_input_params, + "tool_output_params": tool_output_params, + "tool_input_params_nested": tool_input_params_nested, + "tool_output_params_nested": tool_output_params_nested, + "parameters": parameters, + "end_user_metadata": end_user_metadata, + } + +@pytest.fixture +def mock_query_result(test_config): + """Create a mock QueryResult object for testing, without generative info.""" + return types.session.QueryResult( + text=test_config["text"], + language_code=test_config["language_code"], + parameters=struct_pb2.Struct( + fields={ + "some_key": struct_pb2.Value(string_value="some_value") + } + ), + response_messages=[ + types.ResponseMessage( + text=types.ResponseMessage.Text( + text=["Greetings! How can I assist?"] + ) + ), + types.ResponseMessage( + text=types.ResponseMessage.Text( + text=["Hey! What can I help you with today?"] + ) + ) + ], + intent_detection_confidence=1, + match=types.session.Match( + match_type=types.session.Match.MatchType.PLAYBOOK, + confidence=1 + ), + advanced_settings = types.AdvancedSettings( + speech_settings = types.AdvancedSettings.SpeechSettings( + endpointer_sensitivity=90, + ), + logging_settings = types.AdvancedSettings.LoggingSettings( + enable_stackdriver_logging=True, + enable_interaction_logging=True + ) + ), + ) + +@pytest.fixture +def mock_query_result_tools_playbooks_flows(test_config): + """Create a mock QueryResult object for testing, with generative info.""" + return types.session.QueryResult( + text=test_config["text"], + language_code=test_config["language_code"], + generative_info = types.session.GenerativeInfo( + action_tracing_info=types.Example( + actions=[ + types.Action( + tool_use=types.example.ToolUse( + tool=test_config["tool_id"], + action=test_config["tool_action"], + input_action_parameters=test_config["tool_input_params"], + output_action_parameters=test_config["tool_output_params"], + ), + ), + types.Action( + playbook_invocation=types.example.PlaybookInvocation( + playbook=test_config["playbook_id"] + ), + ), + types.Action( + flow_invocation=types.example.FlowInvocation( + flow=test_config["flow_id"] + ) + ), + types.Action( + agent_utterance=types.AgentUtterance(text="Hey there!") + ) + ] + ), + current_playbooks=[test_config["playbook_id"]] + ), + ) + +@pytest.fixture +def mock_query_result_datastore(test_config): + """Create a mock QueryResult object for testing datastore responses.""" + return types.session.QueryResult( + text="who is the ceo?", + language_code="en", + parameters=struct_pb2.Struct( + fields={ + "some_key": struct_pb2.Value(string_value="some_value") + } + ), + response_messages=[ + types.ResponseMessage( + text=types.ResponseMessage.Text( + text=["Sundar Pichai is the CEO of Google.\nhttps://www.google.com"] + ) + ), + types.ResponseMessage( + payload=struct_pb2.Struct( + fields={ + "richContent": struct_pb2.Value( + list_value=struct_pb2.ListValue( + values=[ + struct_pb2.Value( + list_value=struct_pb2.ListValue( + values=[ + struct_pb2.Value( + struct_value=struct_pb2.Struct( + fields={ + "type": struct_pb2.Value(string_value="info"), + "title": struct_pb2.Value(string_value="CEO of Google"), + "subtitle": struct_pb2.Value(string_value="Information on Google executive team."), + "metadata": struct_pb2.Value(struct_value=struct_pb2.Struct()), + "actionLink": struct_pb2.Value(string_value="https://www.google.com"), + } + ) + ) + ] + ) + ) + ] + ) + ) + } + ) + ) + ], + intent_detection_confidence=1, + match=types.session.Match( + match_type=types.session.Match.MatchType.KNOWLEDGE_CONNECTOR, + confidence=1 + ), + advanced_settings = types.AdvancedSettings( + logging_settings = types.AdvancedSettings.LoggingSettings( + enable_stackdriver_logging=True, + enable_interaction_logging=True + ) + ), + current_page = types.Page( + name=test_config["page_id"], + display_name="Start Page" + ) + ) + +@pytest.fixture +def mock_query_result_params_no_text_input(test_config): + """Create a mock QueryResult object for testing.""" + return types.session.QueryResult( + language_code=test_config["language_code"], + ) + +@pytest.fixture +def mock_detect_intent_response(mock_query_result): + """Create a mock DetectIntentResponse object for testing.""" + return types.session.DetectIntentResponse(query_result=mock_query_result) + +@pytest.fixture +def mock_detect_intent_response_no_text_input( + mock_query_result_params_no_text_input): + """Create a mock DetectIntentResponse object for testing.""" + return types.session.DetectIntentResponse( + query_result=mock_query_result_params_no_text_input) + +@pytest.fixture +def mock_detect_intent_response_tools_playbooks_flows( + mock_query_result_tools_playbooks_flows): + """Create a mock DetectIntentResponse object for testing.""" + return types.session.DetectIntentResponse( + query_result=mock_query_result_tools_playbooks_flows) + +@pytest.fixture +def mock_environment_obj(test_config): + """Create a mock Environment object for testing.""" + return types.Environment( + name=test_config["environment_id"], + display_name=test_config["environment_name"] + ) + +@pytest.fixture(autouse=True) +def mock_sessions_client(test_config): + """Fixture to create a mocked SessionsClient.""" + with patch("dfcx_scrapi.core.scrapi_base.default") as mock_default, \ + patch("dfcx_scrapi.core.scrapi_base.Request") as mock_request, \ + patch("dfcx_scrapi.core.sessions.services.sessions.SessionsClient") as mock_client, \ + patch("dfcx_scrapi.core.environments.Environments.__init__") as mock_env, \ + patch("dfcx_scrapi.core.tools.Tools.__init__") as mock_tools, \ + patch("dfcx_scrapi.core.playbooks.Playbooks.__init__") as mock_playbooks, \ + patch("dfcx_scrapi.core.flows.Flows.__init__") as mock_flows: + + mock_creds = MagicMock() + mock_default.return_value = (mock_creds, test_config["project_id"]) + mock_request.return_value = MagicMock() + mock_env.return_value = None + mock_tools.return_value = None + mock_playbooks.return_value = None + mock_flows.return_value = None + + yield mock_client + +@pytest.fixture +def mock_environments_client(test_config): + """Fixture to create a mocked EnvironmentsClient.""" + with patch("dfcx_scrapi.core.scrapi_base.default") as mock_default, \ + patch("dfcx_scrapi.core.scrapi_base.Request") as mock_request, \ + patch("dfcx_scrapi.core.environments.services.environments.EnvironmentsClient") as mock_client: + + mock_creds = MagicMock() + mock_default.return_value = (mock_creds, test_config["project_id"]) + mock_request.return_value = MagicMock() + + yield mock_client + + +@pytest.fixture +def mock_tools_client(test_config): + """Fixture to create a mocked ToolsClient.""" + with patch("dfcx_scrapi.core.scrapi_base.default") as mock_default, \ + patch("dfcx_scrapi.core.scrapi_base.Request") as mock_request, \ + patch("dfcx_scrapi.core.tools.services.tools.ToolsClient") as mock_client: + + mock_creds = MagicMock() + mock_default.return_value = (mock_creds, test_config["project_id"]) + mock_request.return_value = MagicMock() + + yield mock_client + +@pytest.fixture +def mock_playbooks_client(test_config): + """Fixture to create a mocked PlaybooksClient.""" + with patch("dfcx_scrapi.core.scrapi_base.default") as mock_default, \ + patch("dfcx_scrapi.core.scrapi_base.Request") as mock_request, \ + patch("dfcx_scrapi.core.playbooks.services.playbooks.PlaybooksClient") as mock_client: + + mock_creds = MagicMock() + mock_default.return_value = (mock_creds, test_config["project_id"]) + mock_request.return_value = MagicMock() + + yield mock_client + +@pytest.fixture +def mock_flows_client(test_config): + """Fixture to create a mocked FlowsClient.""" + with patch("dfcx_scrapi.core.scrapi_base.default") as mock_default, \ + patch("dfcx_scrapi.core.scrapi_base.Request") as mock_request, \ + patch("dfcx_scrapi.core.flows.services.flows.FlowsClient") as mock_client: + + mock_creds = MagicMock() + mock_default.return_value = (mock_creds, test_config["project_id"]) + mock_request.return_value = MagicMock() + + yield mock_client + +@pytest.fixture +def mock_tools_instance(test_config): + tools_instance = MagicMock() + tools_instance.get_tools_map.return_value = { + test_config["tool_id"]: "mock tool" + } + + return tools_instance + +def test_session_id_property_valid_id(test_config): + """Test session_id property with a valid session ID.""" + session = Sessions( + agent_id = test_config["agent_id"], + session_id = test_config["session_id_plain"]) + assert session.session_id == test_config["session_id_plain"] + +def test_printmd(test_config): + session = Sessions() + with patch("dfcx_scrapi.core.sessions.display") as mock_display: + session.printmd("test string") + mock_display.assert_called() + +def test_build_query_input(test_config): + session = Sessions() + query_input = session._build_query_input( + text=test_config["text"], language_code=test_config["language_code"] + ) + assert isinstance(query_input, types.session.QueryInput) + assert query_input.text.text == test_config["text"] + assert query_input.language_code == test_config["language_code"] + + +def test_build_intent_query_input(test_config): + session = Sessions() + query_input = session.build_intent_query_input( + intent_id=test_config["intent_id"], language_code=test_config["language_code"] + ) + assert isinstance(query_input, types.session.QueryInput) + assert query_input.intent.intent == test_config["intent_id"] + assert query_input.language_code == test_config["language_code"] + +def test_get_tool_action(test_config): + session = Sessions() + mock_tool_use = types.example.ToolUse(action=test_config["tool_action"]) + action = session.get_tool_action(mock_tool_use) + + assert action == test_config["tool_action"] + +def test_get_tool_params(test_config): + session = Sessions() + tool_use = types.ToolUse() + tool_use.input_action_parameters = test_config["tool_input_params"] + res = session.get_tool_params(tool_use.input_action_parameters) + assert res == test_config["tool_input_params"] + +def test_get_tool_params_empty(): + session = Sessions() + params = {} + res = session.get_tool_params(params) + assert res == {} + +def test_get_tool_params_nested(test_config): + session = Sessions() + tool_use = types.ToolUse() + tool_use.input_action_parameters = test_config["tool_input_params_nested"] + res = session.get_tool_params(tool_use.input_action_parameters) + assert res == test_config["tool_input_params_nested"] + +def test_get_playbook_name(test_config, monkeypatch): + mock_playbooks_instance = MagicMock() + mock_playbooks_instance.get_playbooks_map.return_value = { + test_config["playbook_id"]: test_config["playbook_name"] + } + + session = Sessions(agent_id=test_config["agent_id"]) + monkeypatch.setattr(session, "_playbooks_client", mock_playbooks_instance) + + name = session.get_playbook_name(playbook_id=test_config["playbook_id"]) + assert name == test_config["playbook_name"] + + +def test_get_tool_name(test_config, mock_tools_instance, monkeypatch): + session = Sessions(agent_id=test_config["agent_id"]) + monkeypatch.setattr(session, "_tools_client", mock_tools_instance) + + mock_tool_use = types.example.ToolUse(tool=test_config["tool_id"]) + name = session.get_tool_name(tool_use=mock_tool_use) + assert name == "mock tool" + +def test_get_flow_name(test_config, monkeypatch): + mock_flows_instance = MagicMock() + mock_flows_instance.get_flows_map.return_value = { + test_config["flow_id"]: "mock flow" + } + + session = Sessions(agent_id=test_config["agent_id"]) + monkeypatch.setattr(session, "_flows_client", mock_flows_instance) + + name = session.get_flow_name(flow_id=test_config["flow_id"]) + assert name == "mock flow" + +def test_collect_tool_responses( + test_config, mock_tools_instance, + mock_query_result_tools_playbooks_flows, monkeypatch): + + session = Sessions(agent_id=test_config["agent_id"]) + monkeypatch.setattr(session, "_tools_client", mock_tools_instance) + tool_responses = session.collect_tool_responses( + mock_query_result_tools_playbooks_flows) + + assert len(tool_responses) == 1 + assert tool_responses[0]["tool_name"] == "mock tool" + assert tool_responses[0]["tool_action"] == test_config["tool_action"] + assert tool_responses[0]["input_params"] == test_config["tool_input_params"] + assert tool_responses[0]["output_params"] == test_config["tool_output_params"] + +def test_collect_playbook_responses( + test_config, mock_query_result_tools_playbooks_flows, monkeypatch): + mock_playbooks_instance = MagicMock() + mock_playbooks_instance.get_playbooks_map.return_value = { + test_config["playbook_id"]: test_config["playbook_name"] + } + session = Sessions(agent_id=test_config["agent_id"]) + monkeypatch.setattr(session, "_playbooks_client", mock_playbooks_instance) + playbook_responses = session.collect_playbook_responses( + mock_query_result_tools_playbooks_flows) + + assert len(playbook_responses) == 4 + assert playbook_responses[0]["playbook_name"] == test_config["playbook_name"] + +def test_collect_playbook_responses_no_playbook_invocation( + test_config, monkeypatch): + mock_playbooks_instance = MagicMock() + mock_playbooks_instance.get_playbooks_map.return_value = { + test_config["playbook_id"]: test_config["playbook_name"] + } + session = Sessions(agent_id=test_config["agent_id"]) + monkeypatch.setattr(session, "_playbooks_client", mock_playbooks_instance) + + mock_query_result = types.session.QueryResult( + generative_info = types.session.GenerativeInfo( + action_tracing_info=types.Example( + actions=[], + ), + current_playbooks=[test_config["playbook_id"]] + ) + ) + playbook_responses = session.collect_playbook_responses(mock_query_result) + + assert len(playbook_responses) == 0 + +def test_collect_flow_responses(test_config, monkeypatch): + mock_flows_instance = MagicMock() + mock_flows_instance.get_flows_map.return_value = { + test_config["flow_id"]: "mock flow" + } + session = Sessions(agent_id=test_config["agent_id"]) + monkeypatch.setattr(session, "_flows_client", mock_flows_instance) + + mock_query_result = types.session.QueryResult( + generative_info = types.session.GenerativeInfo( + action_tracing_info=types.Example( + actions=[ + types.Action( + flow_invocation=types.example.FlowInvocation( + flow=test_config["flow_id"] + ), + ), + ], + ), + ) + ) + flow_responses = session.collect_flow_responses(mock_query_result) + assert len(flow_responses) == 1 + assert flow_responses[0]["flow_name"] == "mock flow" + +@patch("uuid.uuid4") +def test_build_session_id_no_environment(mock_uuid4, test_config): + mock_uuid4.return_value = uuid.UUID("a1b2c3d4-e5f6-7890-1234-567890abcdef") + + session = Sessions() + session_id = session.build_session_id(agent_id=test_config["agent_id"]) + + assert session_id == test_config["session_id_plain"] + assert session.session_id == test_config["session_id_plain"] + +@patch("uuid.uuid4") +def test_build_session_id_with_environment( + mock_uuid4, test_config, mock_environment_obj, monkeypatch): + mock_uuid4.return_value = uuid.UUID("a1b2c3d4-e5f6-7890-1234-567890abcdef") + + mock_env_instance = MagicMock() + mock_env_instance.get_environment_by_display_name.return_value = mock_environment_obj + + session = Sessions(agent_id=test_config["agent_id"]) + monkeypatch.setattr(session, "_env_client", mock_env_instance) + + session_id = session.build_session_id( + agent_id=test_config["agent_id"], + environment_name=test_config["environment_name"] + ) + + assert test_config["environment_id"] in session_id + assert session.session_id == session_id + assert session.session_id == test_config["session_id_with_env"] + +def test_build_session_id_invalid_env(test_config, monkeypatch): + mock_env_instance = MagicMock() + mock_env_instance.get_environment_by_display_name.return_value = None + + session = Sessions(agent_id=test_config["agent_id"]) + monkeypatch.setattr(session, "_env_client", mock_env_instance) + + with pytest.raises(ValueError, match="Environment `BAD_ENV` does not exist."): + session.build_session_id( + agent_id=test_config["agent_id"], + environment_name="BAD_ENV" + ) + +@patch("uuid.uuid4") +def test_build_session_id_no_overwrite(mock_uuid4, test_config): + mock_uuid4.return_value = uuid.UUID("a1b2c3d4-9999-9999-9999-567890abcdef") + + session = Sessions(session_id=test_config["session_id_plain"]) + session_id = session.build_session_id(agent_id=test_config["agent_id"], overwrite=False) + + assert session.session_id == test_config["session_id_plain"] + assert session_id != session.session_id + +def test_detect_intent( + test_config, mock_sessions_client, mock_detect_intent_response +): + mock_detect_intent = MagicMock(return_value=mock_detect_intent_response) + mock_sessions_client.return_value.detect_intent = mock_detect_intent + + session = Sessions() + query_result = session.detect_intent( + agent_id=test_config["agent_id"], + session_id=test_config["session_id_plain"], + text=test_config["text"], + language_code=test_config["language_code"], + ) + + assert query_result == mock_detect_intent_response.query_result + assert isinstance(query_result, types.QueryResult) + assert query_result.language_code == "en" #default lang code + + mock_detect_intent.assert_called_once() + request = mock_detect_intent.call_args.kwargs['request'] + assert not request.query_params.end_user_metadata + assert not request.query_params.parameters + assert not request.query_params.time_zone + +def test_detect_intent_params_and_end_user_metadata( + test_config, mock_sessions_client, mock_detect_intent_response +): + mock_detect_intent = MagicMock(return_value=mock_detect_intent_response) + mock_sessions_client.return_value.detect_intent = mock_detect_intent + + session = Sessions() + query_result = session.detect_intent( + agent_id=test_config["agent_id"], + session_id=test_config["session_id_plain"], + text=test_config["text"], + language_code=test_config["language_code"], + parameters=test_config["parameters"], + end_user_metadata=test_config["end_user_metadata"] + ) + + assert query_result == mock_detect_intent_response.query_result + assert isinstance(query_result, types.QueryResult) + assert query_result.language_code == "en" #default lang code + + mock_detect_intent.assert_called_once() + request = mock_detect_intent.call_args.kwargs['request'] + assert request.query_params.end_user_metadata == test_config["end_user_metadata"] + assert request.query_params.parameters == test_config["parameters"] + assert not request.query_params.time_zone + + +def test_detect_intent_with_timezone( + test_config, mock_sessions_client, mock_detect_intent_response +): + mock_detect_intent = MagicMock(return_value=mock_detect_intent_response) + mock_sessions_client.return_value.detect_intent = mock_detect_intent + + session = Sessions() + query_result = session.detect_intent( + agent_id=test_config["agent_id"], + session_id=test_config["session_id_plain"], + text=test_config["text"], + language_code=test_config["language_code"], + timezone="America/Los_Angeles" + ) + + assert query_result == mock_detect_intent_response.query_result + assert isinstance(query_result, types.QueryResult) + assert query_result.language_code == "en" #default lang code + + mock_detect_intent.assert_called_once() + request = mock_detect_intent.call_args.kwargs['request'] + assert request.query_params.time_zone == "America/Los_Angeles" + +# TODO (pmarlow): Tracking b/384222123 which causes Data Store Signals to +# "fail open", meaning they are always populated. Revise tests once this is +# resolved. +def test_detect_intent_with_data_store_signals( + test_config, mock_sessions_client, mock_detect_intent_response +): + mock_detect_intent = MagicMock(return_value=mock_detect_intent_response) + mock_sessions_client.return_value.detect_intent = mock_detect_intent + + session = Sessions() + query_result = session.detect_intent( + agent_id=test_config["agent_id"], + session_id=test_config["session_id_plain"], + text=test_config["text"], + language_code=test_config["language_code"], + populate_data_store_connection_signals=True + ) + + assert query_result == mock_detect_intent_response.query_result + + mock_detect_intent.assert_called_once() + request = mock_detect_intent.call_args.kwargs['request'] + assert request.query_params.populate_data_store_connection_signals + +def test_detect_intent_with_intent_id( + test_config, mock_sessions_client, mock_detect_intent_response +): + mock_detect_intent = MagicMock(return_value=mock_detect_intent_response) + mock_sessions_client.return_value.detect_intent = mock_detect_intent + + session = Sessions() + query_result = session.detect_intent( + agent_id=test_config["agent_id"], + session_id=test_config["session_id_plain"], + intent_id=test_config["intent_id"], + language_code=test_config["language_code"] + ) + + assert query_result == mock_detect_intent_response.query_result + + mock_detect_intent.assert_called_once() + request = mock_detect_intent.call_args.kwargs['request'] + assert request.query_input.intent.intent == test_config["intent_id"] + +def test_detect_intent_invalid_session_id(test_config): + session = Sessions() + with pytest.raises( + ValueError, + match="Session ID must be provided in the following format:"): + session.detect_intent( + agent_id=test_config["agent_id"], + session_id="invalid_session_id", + text=test_config["text"], + language_code=test_config["language_code"], + ) + +def test_detect_intent_preset_parameters_no_text_input( + test_config, mock_sessions_client, + mock_detect_intent_response_no_text_input): + mock_detect_intent = MagicMock( + return_value=mock_detect_intent_response_no_text_input) + mock_sessions_client.return_value.detect_intent = mock_detect_intent + + session = Sessions() + query_result = session.detect_intent( + agent_id=test_config["agent_id"], + session_id=test_config["session_id_plain"], + parameters=test_config["parameters"], + ) + + assert query_result == mock_detect_intent_response_no_text_input.query_result + mock_detect_intent.assert_called_once() + request = mock_detect_intent.call_args.kwargs['request'] + assert request.query_params.parameters == test_config["parameters"] + assert not request.query_input.text.text + +def test_preset_parameters_no_agent_id(test_config): + session = Sessions() + with pytest.raises(ValueError, match="resource_id must not be None"): + session.detect_intent( + agent_id=None, + session_id=test_config["session_id_plain"], + parameters=test_config["parameters"], + ) + +def test_preset_parameters_no_session_id(test_config): + session = Sessions() + with pytest.raises(ValueError, match="Session ID must be provided in the following format:"): + session.detect_intent( + agent_id=test_config["agent_id"], + session_id=None, + parameters=test_config["parameters"], + ) + +def test_get_agent_answer(test_config, mock_sessions_client, monkeypatch, mock_query_result_datastore): + mock_detect_intent_response = MagicMock() + mock_detect_intent_response.query_result = mock_query_result_datastore + + mock_detect_intent = MagicMock(return_value=mock_detect_intent_response) + mock_sessions_client.return_value.detect_intent = mock_detect_intent + + session = Sessions(agent_id=test_config["agent_id"]) + monkeypatch.setattr(session, "build_session_id", + MagicMock(return_value=test_config["session_id_plain"])) + + answer = session.get_agent_answer(user_query="who is the ceo?") + assert answer == "Sundar Pichai is the CEO of Google.\nhttps://www.google.com (https://www.google.com)" + mock_detect_intent.assert_called_once() + +def test_get_agent_answer_no_citation(test_config, mock_sessions_client, monkeypatch): + mock_detect_intent_response = MagicMock() + mock_detect_intent_response.query_result = types.session.QueryResult( + response_messages=[ + types.ResponseMessage( + text=types.ResponseMessage.Text( + text=["Test Response"] + ) + ) + ] + ) + + mock_detect_intent = MagicMock(return_value=mock_detect_intent_response) + mock_sessions_client.return_value.detect_intent = mock_detect_intent + + session = Sessions(agent_id=test_config["agent_id"]) + monkeypatch.setattr(session, "build_session_id", + MagicMock(return_value=test_config["session_id_plain"])) + + answer = session.get_agent_answer(user_query="test query") + assert answer == "Test Response ()" + mock_detect_intent.assert_called_once() + +def test_parse_result( + test_config, mock_query_result_tools_playbooks_flows, monkeypatch +): + session = Sessions(agent_id=test_config["agent_id"]) + mock_printmd = MagicMock() + monkeypatch.setattr(session, "printmd", mock_printmd) + + session.parse_result(mock_query_result_tools_playbooks_flows) + + mock_printmd.assert_called() + calls = mock_printmd.mock_calls + + tool_call_font = "TOOL CALL:" + tool_res_font = "TOOL RESULT:" + query_font = " USER QUERY:" + response_font = "AGENT RESPONSE:" + + # Check for the user query + assert any(query_font in call.args[0] for call in calls) + assert any(tool_call_font in call.args[0] for call in calls) + assert any(tool_res_font in call.args[0] for call in calls) + assert any(response_font in call.args[0] for call in calls) diff --git a/tests/dfcx_scrapi/core/test_test_cases.py b/tests/dfcx_scrapi/core/test_test_cases.py index c2c46900..4efefae1 100644 --- a/tests/dfcx_scrapi/core/test_test_cases.py +++ b/tests/dfcx_scrapi/core/test_test_cases.py @@ -334,7 +334,7 @@ def test_list_test_cases_agent_id_not_in_instance( tc = PyTestCases() - with pytest.raises(AttributeError): + with pytest.raises(ValueError): _ = tc.list_test_cases()