diff --git a/src/dfcx_scrapi/core/scrapi_base.py b/src/dfcx_scrapi/core/scrapi_base.py index c9335653..450550ee 100644 --- a/src/dfcx_scrapi/core/scrapi_base.py +++ b/src/dfcx_scrapi/core/scrapi_base.py @@ -1,6 +1,6 @@ """Base for other SCRAPI classes.""" -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,8 +22,9 @@ import threading import vertexai from collections import defaultdict -from typing import Dict, Any, Iterable +from typing import Dict, Any, Iterable, List +from google.auth import default from google.api_core import exceptions from google.cloud.dialogflowcx_v3beta1 import types from google.oauth2 import service_account @@ -68,24 +69,25 @@ ALL_GENERATIVE_MODELS = ALL_GEMINI_MODELS + TEXT_GENERATION_MODELS +# Define global scopes used for all Dialogflow CX Requests +GLOBAL_SCOPES = [ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/dialogflow", + ] + class ScrapiBase: """Core Class for managing Auth and other shared functions.""" - global_scopes = [ - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/dialogflow", - ] - def __init__( self, creds_path: str = None, creds_dict: Dict[str, str] = None, creds: service_account.Credentials = None, - scope=False, - agent_id=None, + scope: List[str] = None, + agent_id: str = None, ): - self.scopes = ScrapiBase.global_scopes + self.scopes = GLOBAL_SCOPES if scope: self.scopes += scope @@ -93,24 +95,28 @@ def __init__( self.creds = creds self.creds.refresh(Request()) self.token = self.creds.token + elif creds_path: self.creds = service_account.Credentials.from_service_account_file( creds_path, scopes=self.scopes ) self.creds.refresh(Request()) self.token = self.creds.token + elif creds_dict: self.creds = service_account.Credentials.from_service_account_info( creds_dict, scopes=self.scopes ) self.creds.refresh(Request()) self.token = self.creds.token + else: - self.creds = None - self.token = None + self.creds, _ = default() + self.creds.refresh(Request()) + self.token = self.creds.token + self._check_and_update_scopes(self.creds) - if agent_id: - self.agent_id = agent_id + self.agent_id = agent_id self.api_calls_dict = defaultdict(int) @@ -223,13 +229,22 @@ def dict_to_struct(some_dict: Dict[str, Any]): @staticmethod def parse_agent_id(resource_id: str): """Attempts to parse Agent ID from provided Resource ID.""" - try: - agent_id = "/".join(resource_id.split("/")[:6]) - except IndexError as err: - logging.error("IndexError - path too short? %s", resource_id) - raise err + parts = resource_id.split("/") + if len(parts) < 6: + raise ValueError( + "Resource ID is too short to contain an Agent ID: {}".format( + resource_id + ) + ) - return agent_id + if parts[4] != "agents": + raise ValueError( + "Resource ID does not contain an agent ID: {}".format( + resource_id + ) + ) + + return "/".join(parts[:6]) @staticmethod def _parse_resource_path( @@ -400,6 +415,14 @@ def is_valid_sys_instruct_model(llm_model: str) -> bool: return valid_sys_instruct + def _check_and_update_scopes(self, creds: Any): + """Update Credentials scopes if possible based on creds type.""" + if creds.requires_scopes: + self.creds.scopes.extend(GLOBAL_SCOPES) + + else: + logging.info("Found user creds, skipping global scopes...") + def build_generative_model( self, llm_model: str, diff --git a/src/dfcx_scrapi/tools/dataframe_functions.py b/src/dfcx_scrapi/tools/dataframe_functions.py index d9a233a6..86e53eac 100644 --- a/src/dfcx_scrapi/tools/dataframe_functions.py +++ b/src/dfcx_scrapi/tools/dataframe_functions.py @@ -17,7 +17,7 @@ import json import logging import time -from typing import Dict, List +from typing import Dict, List, Any import gspread import pandas as pd import numpy as np @@ -33,7 +33,7 @@ from dfcx_scrapi.core.pages import Pages from dfcx_scrapi.core.transition_route_groups import TransitionRouteGroups -GLOBAL_SCOPE = [ +SHEETS_SCOPE = [ "https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive", ] @@ -52,8 +52,8 @@ def __init__( self, creds_path: str = None, creds_dict: dict = None, - creds=None, - scope=False, + creds: Any = None, + scope: List[str] = None, ): super().__init__( creds_path=creds_path, @@ -62,14 +62,16 @@ def __init__( scope=scope, ) - scopes = GLOBAL_SCOPE + self._check_and_update_sheets_scopes() - if scope: - scopes += scope - - self.creds.scopes.extend(scopes) + if hasattr(self.creds, "service_account_email") and self.creds.service_account_email: + self.sheets_client = gspread.authorize(self.creds) + else: + logging.warning( + "Application Default Credentials (ADC) found and Sheets Client" + " could not be authorized. Use Service Account or Oauth2 user" + " credentials if you require Sheets access.") - self.sheets_client = gspread.authorize(self.creds) self.entities = EntityTypes(creds=self.creds) self.intents = Intents(creds=self.creds) self.flows = Flows(creds=self.creds) @@ -138,6 +140,11 @@ def _remap_intent_values(original_intent: types.Intent) -> types.Intent: return new_intent + def _check_and_update_sheets_scopes(self): + """Update Credentials scopes if possible based on creds type.""" + if self.creds.requires_scopes: + self.creds.scopes.extend(SHEETS_SCOPE) + def _update_intent_from_dataframe( self, intent_id: str, diff --git a/tests/dfcx_scrapi/core/test_agents.py b/tests/dfcx_scrapi/core/test_agents.py index 03c1aa77..856edd45 100644 --- a/tests/dfcx_scrapi/core/test_agents.py +++ b/tests/dfcx_scrapi/core/test_agents.py @@ -17,10 +17,13 @@ # limitations under the License. import pytest -from unittest.mock import patch +from typing import Dict +from unittest.mock import patch, MagicMock from google.cloud.dialogflowcx_v3beta1 import types -from google.cloud.dialogflowcx_v3beta1 import services +from google.cloud.dialogflowcx_v3beta1.services.agents import ( + pagers, AgentsClient + ) from dfcx_scrapi.core.agents import Agents @@ -44,7 +47,7 @@ def test_config(): } @pytest.fixture -def mock_agent_obj_flow(test_config): +def mock_agent_obj_flow(test_config: Dict[str, str]): return types.Agent( name=test_config["agent_id"], display_name=test_config["display_name"], @@ -54,7 +57,7 @@ def mock_agent_obj_flow(test_config): ) @pytest.fixture -def mock_agent_obj_playbook(test_config): +def mock_agent_obj_playbook(test_config: Dict[str, str]): return types.Agent( name=test_config["agent_id"], display_name=test_config["display_name"], @@ -64,36 +67,48 @@ def mock_agent_obj_playbook(test_config): ) @pytest.fixture -def mock_agent_obj_kwargs(mock_agent_obj_flow): +def mock_agent_obj_kwargs(mock_agent_obj_flow: types.Agent): mock_agent_obj_flow.description = "This is a Mock Agent description." mock_agent_obj_flow.enable_stackdriver_logging = True return mock_agent_obj_flow @pytest.fixture -def mock_updated_agent_obj(mock_agent_obj_flow): +def mock_updated_agent_obj(mock_agent_obj_flow: types.Agent): mock_agent_obj_flow.display_name = "Updated Agent Display Name" return mock_agent_obj_flow @pytest.fixture -def mock_list_agents_response(mock_agent_obj_flow): +def mock_list_agents_response(mock_agent_obj_flow: types.Agent): return types.agent.ListAgentsResponse(agents=[mock_agent_obj_flow]) @pytest.fixture -def mock_list_agents_pager(mock_list_agents_response): - return services.agents.pagers.ListAgentsPager( - services.agents.AgentsClient.list_agents, +def mock_list_agents_pager(mock_list_agents_response: types.ListAgentsResponse): + return pagers.ListAgentsPager( + AgentsClient.list_agents, types.agent.ListAgentsRequest(), mock_list_agents_response, ) +@pytest.fixture(autouse=True) +def mock_client(test_config: Dict[str, str]): + """Setup fixture for Agents Class to be used with all tests.""" + with patch("dfcx_scrapi.core.scrapi_base.default") as mock_default, \ + patch("dfcx_scrapi.core.scrapi_base.Request") as mock_request, \ + patch("dfcx_scrapi.core.agents.services.agents.AgentsClient") as mock_client: + mock_creds = MagicMock() + mock_default.return_value = (mock_creds, test_config["project_id"]) + mock_request.return_value = MagicMock() + + yield mock_client # Return control to test method + # Test list_agents with location_id -@patch("dfcx_scrapi.core.agents.Agents._list_agents_client_request") -def test_list_agents_with_location(mock_list_agents_client_request, - mock_agent_obj_flow, - test_config - ): - mock_list_agents_client_request.return_value = [mock_agent_obj_flow] +def test_list_agents_with_location( + mock_client: MagicMock, + mock_list_agents_pager: pagers.ListAgentsPager, + test_config: Dict[str, str]): + mock_client.return_value.list_agents.return_value = mock_list_agents_pager + agent = Agents() agents = agent.list_agents( project_id=test_config["project_id"], @@ -104,12 +119,12 @@ def test_list_agents_with_location(mock_list_agents_client_request, assert agents[0].name == test_config["agent_id"] # Test list_agents without location_id -@patch("dfcx_scrapi.core.agents.Agents._list_agents_client_request") -def test_list_agents_without_location(mock_list_agents_client_request, - mock_agent_obj_flow, - test_config - ): - mock_list_agents_client_request.return_value = [mock_agent_obj_flow] +def test_list_agents_without_location( + mock_client: MagicMock, + mock_list_agents_pager: pagers.ListAgentsPager, + test_config: Dict[str, str]): + mock_client.return_value.list_agents.return_value = mock_list_agents_pager + agent = Agents() agents = agent.list_agents(project_id=test_config["project_id"]) assert isinstance(agents, list) @@ -117,8 +132,10 @@ def test_list_agents_without_location(mock_list_agents_client_request, assert agents[0].name == test_config["agent_id"] # Test get_agent -@patch("dfcx_scrapi.core.agents.services.agents.AgentsClient") -def test_get_agent(mock_client, mock_agent_obj_flow, test_config): +def test_get_agent( + mock_client: MagicMock, + mock_agent_obj_flow: types.Agent, + test_config: Dict[str, str]): mock_client.return_value.get_agent.return_value = mock_agent_obj_flow agent = Agents() response = agent.get_agent(test_config["agent_id"]) @@ -126,9 +143,11 @@ def test_get_agent(mock_client, mock_agent_obj_flow, test_config): assert response.name == test_config["agent_id"] assert response.display_name == test_config["display_name"] -@patch("dfcx_scrapi.core.agents.services.agents.AgentsClient") def test_get_agent_by_display_name_no_location( - mock_client, mock_agent_obj_flow, mock_list_agents_pager, test_config): + mock_client: MagicMock, + mock_agent_obj_flow: types.Agent, + mock_list_agents_pager: pagers.ListAgentsPager, + test_config: Dict[str, str]): mock_client.return_value.get_agent_by_display_name.return_value = mock_agent_obj_flow # pylint: disable=C0301 mock_client.return_value.list_agents.return_value = mock_list_agents_pager agent = Agents() @@ -138,9 +157,11 @@ def test_get_agent_by_display_name_no_location( assert response is None -@patch("dfcx_scrapi.core.agents.services.agents.AgentsClient") def test_get_agent_by_display_name_with_region( - mock_client, mock_agent_obj_flow, mock_list_agents_pager, test_config): + mock_client: MagicMock, + mock_agent_obj_flow: types.Agent, + mock_list_agents_pager: pagers.ListAgentsPager, + test_config: Dict[str, str]): mock_client.return_value.get_agent_by_display_name.return_value = mock_agent_obj_flow # pylint: disable=C0301 mock_client.return_value.list_agents.return_value = mock_list_agents_pager agent = Agents() @@ -155,9 +176,10 @@ def test_get_agent_by_display_name_with_region( assert response.display_name == test_config["display_name"] # Test create_agent -@patch("dfcx_scrapi.core.agents.services.agents.AgentsClient") def test_create_agent_with_kwargs( - mock_client, mock_agent_obj_flow, test_config): + mock_client: MagicMock, + mock_agent_obj_flow: types.Agent, + test_config: Dict[str, str]): mock_client.return_value.create_agent.return_value = mock_agent_obj_flow agent = Agents() response = agent.create_agent( @@ -168,8 +190,10 @@ def test_create_agent_with_kwargs( assert response.name == test_config["agent_id"] assert response.display_name == test_config["display_name"] -@patch("dfcx_scrapi.core.agents.services.agents.AgentsClient") -def test_create_agent_from_obj(mock_client, mock_agent_obj_flow, test_config): +def test_create_agent_from_obj( + mock_client: MagicMock, + mock_agent_obj_flow: types.Agent, + test_config: Dict[str, str]): mock_client.return_value.create_agent.return_value = mock_agent_obj_flow agents = Agents() @@ -180,8 +204,9 @@ def test_create_agent_from_obj(mock_client, mock_agent_obj_flow, test_config): assert res.display_name == test_config["display_name"] -@patch("dfcx_scrapi.core.agents.services.agents.AgentsClient") -def test_create_agent_from_obj_with_kwargs(mock_client, mock_agent_obj_kwargs): +def test_create_agent_from_obj_with_kwargs( + mock_client: MagicMock, + mock_agent_obj_kwargs: types.Agent): mock_client.return_value.create_agent.return_value = mock_agent_obj_kwargs agents = Agents() @@ -196,10 +221,10 @@ def test_create_agent_from_obj_with_kwargs(mock_client, mock_agent_obj_kwargs): assert res.description == "This is a Mock Agent description." # Test update_agent -@patch("dfcx_scrapi.core.agents.services.agents.AgentsClient") -def test_update_agent_with_obj(mock_client, - mock_updated_agent_obj, - test_config): +def test_update_agent_with_obj( + mock_client: MagicMock, + mock_updated_agent_obj: types.Agent, + test_config: Dict[str, str]): mock_client.return_value.update_agent.return_value = ( mock_updated_agent_obj ) @@ -212,10 +237,9 @@ def test_update_agent_with_obj(mock_client, assert response.name == test_config["agent_id"] assert response.display_name == "Updated Agent Display Name" -@patch("dfcx_scrapi.core.agents.services.agents.AgentsClient") -def test_update_agent_with_kwargs(mock_client, - mock_agent_obj_flow, - test_config): +def test_update_agent_with_kwargs(mock_client: MagicMock, + mock_agent_obj_flow: types.Agent, + test_config: Dict[str, str]): mock_client.return_value.get_agent.return_value = mock_agent_obj_flow mock_client.return_value.update_agent.return_value = mock_agent_obj_flow agent = Agents() @@ -229,18 +253,17 @@ def test_update_agent_with_kwargs(mock_client, assert response.display_name == "Updated Agent Display Name" # Test delete_agent -@patch("dfcx_scrapi.core.agents.services.agents.AgentsClient") -def test_delete_agent(test_config): +def test_delete_agent(test_config: Dict[str, str]): agent = Agents() response = agent.delete_agent(agent_id=test_config["agent_id"]) - print(response) assert ( response == f"Agent '{test_config['agent_id']}' successfully deleted." ) -@patch("dfcx_scrapi.core.agents.services.agents.AgentsClient") def test_create_agent_simple_default_region_no_kwargs( - mock_client, mock_agent_obj_flow, test_config + mock_client: MagicMock, + mock_agent_obj_flow: types.Agent, + test_config: Dict[str, str] ): mock_client.return_value.create_agent.return_value = mock_agent_obj_flow @@ -254,8 +277,9 @@ def test_create_agent_simple_default_region_no_kwargs( assert res.display_name == test_config["display_name"] -@patch("dfcx_scrapi.core.agents.services.agents.AgentsClient") -def test_create_agent_with_extra_kwargs(mock_client, mock_agent_obj_kwargs): +def test_create_agent_with_extra_kwargs( + mock_client: MagicMock, + mock_agent_obj_kwargs: types.Agent): mock_client.return_value.create_agent.return_value = mock_agent_obj_kwargs agents = Agents() @@ -267,4 +291,4 @@ def test_create_agent_with_extra_kwargs(mock_client, mock_agent_obj_kwargs): ) assert isinstance(res, types.Agent) - assert res == mock_agent_obj_kwargs + assert res == mock_agent_obj_kwargs \ No newline at end of file diff --git a/tests/dfcx_scrapi/core/test_conversation_history.py b/tests/dfcx_scrapi/core/test_conversation_history.py index 4a88eb48..7fdf6309 100644 --- a/tests/dfcx_scrapi/core/test_conversation_history.py +++ b/tests/dfcx_scrapi/core/test_conversation_history.py @@ -21,7 +21,7 @@ import os import json import pytest -from unittest.mock import patch +from unittest.mock import patch, MagicMock from dfcx_scrapi.core.conversation_history import ConversationHistory from google.cloud.dialogflowcx_v3beta1 import types @@ -31,9 +31,13 @@ @pytest.fixture def test_config(): - agent_id = "projects/mock-test/locations/global/agents/a1s2d3f4" + project_id = "my-project-id-1234" + location_id = "global" + parent = f"projects/{project_id}/locations/{location_id}" + agent_id = f"{parent}/agents/my-agent-1234" conversation_id = f"{agent_id}/conversations/1234" return { + "project_id": project_id, "agent_id": agent_id, "conversation_id": conversation_id, } @@ -82,6 +86,18 @@ def mock_list_conversations_pager(test_conversation): ), ) +@pytest.fixture(autouse=True) +def mock_client(test_config): + """Setup mock client for all tests.""" + with patch("dfcx_scrapi.core.scrapi_base.default") as mock_default, \ + patch("dfcx_scrapi.core.scrapi_base.Request") as mock_request, \ + patch("dfcx_scrapi.core.conversation_history.services.conversation_history.ConversationHistoryClient") as mock_client: + mock_creds = MagicMock() + mock_default.return_value = (mock_creds, test_config["project_id"]) + mock_request.return_value = MagicMock() + + yield mock_client # Return control to test method + # Test get_user_input def test_get_user_input(): query_input = types.QueryInput(text=types.TextInput(text="test input")) @@ -98,10 +114,6 @@ def test_get_query_result(): assert agent_response == "test result" # Test list_conversations -@patch( - "dfcx_scrapi.core.conversation_history.services.conversation_history" - ".ConversationHistoryClient" -) def test_list_conversations( mock_client, mock_list_conversations_pager, test_config ): @@ -114,10 +126,6 @@ def test_list_conversations( assert isinstance(res[0], types.Conversation) # Test get_conversation -@patch( - "dfcx_scrapi.core.conversation_history.services.conversation_history" - ".ConversationHistoryClient" -) def test_get_conversation( mock_client, test_conversation, test_config ): @@ -131,10 +139,6 @@ def test_get_conversation( assert isinstance(res, types.Conversation) # Test delete_conversation -@patch( - "dfcx_scrapi.core.conversation_history.services.conversation_history" - ".ConversationHistoryClient" -) def test_delete_conversation(mock_client, test_config): ch = ConversationHistory(agent_id=test_config["agent_id"]) ch.delete_conversation( @@ -173,10 +177,6 @@ def test_read_conversations_from_file(tmpdir): assert loaded_data == data # Test conversation_history_to_file -@patch( - "dfcx_scrapi.core.conversation_history.services.conversation_history" - ".ConversationHistoryClient" -) @patch("dfcx_scrapi.core.conversation_history.thread_map") def test_conversation_history_to_file( mock_thread_map, mock_client, test_conversation, tmpdir, test_config diff --git a/tests/dfcx_scrapi/core/test_examples.py b/tests/dfcx_scrapi/core/test_examples.py index a6dcfe76..f7ce5676 100644 --- a/tests/dfcx_scrapi/core/test_examples.py +++ b/tests/dfcx_scrapi/core/test_examples.py @@ -19,19 +19,23 @@ # limitations under the License. import pytest -from unittest.mock import MagicMock +from unittest.mock import patch, MagicMock from dfcx_scrapi.core.examples import Examples from google.cloud.dialogflowcx_v3beta1 import types from google.cloud.dialogflowcx_v3beta1 import services @pytest.fixture def test_config(): - agent_id = "projects/mock-test/locations/global/agents/a1s2d3f4" + project_id = "my-project-id-1234" + location_id = "global" + parent = f"projects/{project_id}/locations/{location_id}" + agent_id = f"{parent}/agents/my-agent-1234" playbook_id = f"{agent_id}/playbooks/1234" example_id = f"{playbook_id}/examples/9876" tool_id = f"{agent_id}/tools/4321" display_name = "test_example" return { + "project_id": project_id, "agent_id": agent_id, "playbook_id": playbook_id, "example_id": example_id, @@ -89,38 +93,26 @@ def mock_list_examples_pager(mock_example_obj): types.example.ListExamplesResponse(examples=[mock_example_obj]), ) -@pytest.fixture -def mock_examples(monkeypatch, test_config): - """Fixture to create Example object w/ mocked ExmamplesClient.""" - mock_examples_client = MagicMock() - - # Override / Intercept Playbook/Tool instantiation in Examples init. - def mock_playbooks_init(self, *args, **kwargs): - pass - - def mock_tools_init(self, *args, **kwargs): - pass - - monkeypatch.setattr( - "dfcx_scrapi.core.examples.services.examples.ExamplesClient", - mock_examples_client - ) - monkeypatch.setattr( - "dfcx_scrapi.core.playbooks.Playbooks.__init__", - mock_playbooks_init - ) - monkeypatch.setattr( - "dfcx_scrapi.core.tools.Tools.__init__", - mock_tools_init - ) +@pytest.fixture(autouse=True) +def mock_client(test_config): + """Fixture to create mocked ExamplesClient.""" + with patch("dfcx_scrapi.core.scrapi_base.default") as mock_default, \ + patch("dfcx_scrapi.core.scrapi_base.Request") as mock_request, \ + patch("dfcx_scrapi.core.examples.services.examples.ExamplesClient") as mock_client, \ + patch("dfcx_scrapi.core.playbooks.Playbooks.__init__") as mock_playbooks_init, \ + patch("dfcx_scrapi.core.tools.Tools.__init__") as mock_tools_init: - examples = Examples(agent_id=test_config["agent_id"]) + mock_creds = MagicMock() + mock_default.return_value = (mock_creds, test_config["project_id"]) + mock_request.return_value = MagicMock() + mock_playbooks_init.return_value = None + mock_tools_init.return_value = None - yield examples, mock_examples_client + yield mock_client # Test get_examples_map -def test_get_examples_map(mock_examples, mock_list_examples_pager, test_config): - ex, mock_client = mock_examples +def test_get_examples_map(mock_client, mock_list_examples_pager, test_config): + ex = Examples(agent_id=test_config["agent_id"]) mock_client.return_value.list_examples.return_value = ( mock_list_examples_pager ) @@ -133,8 +125,8 @@ def test_get_examples_map(mock_examples, mock_list_examples_pager, test_config): print(mock_client.mock_calls) # Test list_examples -def test_list_examples(mock_examples, mock_list_examples_pager, test_config): - ex, mock_client = mock_examples +def test_list_examples(mock_client, mock_list_examples_pager, test_config): + ex = Examples(agent_id=test_config["agent_id"]) mock_client.return_value.list_examples.return_value = ( mock_list_examples_pager ) @@ -144,8 +136,8 @@ def test_list_examples(mock_examples, mock_list_examples_pager, test_config): assert isinstance(res[0], types.Example) # Test get_example -def test_get_example(mock_examples, mock_example_obj, test_config): - ex, mock_client = mock_examples +def test_get_example(mock_client, mock_example_obj, test_config): + ex = Examples(agent_id=test_config["agent_id"]) mock_client.return_value.get_example.return_value = mock_example_obj res = ex.get_example(example_id=test_config["example_id"]) @@ -154,8 +146,8 @@ def test_get_example(mock_examples, mock_example_obj, test_config): # Test create_example def test_create_example_from_kwargs( - mock_examples, mock_example_obj, test_config): - ex, mock_client = mock_examples + mock_client, mock_example_obj, test_config): + ex = Examples(agent_id=test_config["agent_id"]) mock_client.return_value.create_example.return_value = mock_example_obj res = ex.create_example( playbook_id=test_config["playbook_id"], @@ -165,8 +157,8 @@ def test_create_example_from_kwargs( assert res.display_name == test_config["display_name"] def test_create_example_from_proto_object( - mock_examples, mock_example_obj, test_config): - ex, mock_client = mock_examples + mock_client, mock_example_obj, test_config): + ex = Examples(agent_id=test_config["agent_id"]) mock_client.return_value.create_example.return_value = mock_example_obj res = ex.create_example( playbook_id=test_config["playbook_id"], @@ -177,8 +169,8 @@ def test_create_example_from_proto_object( # Test update_example def test_update_example_with_obj( - mock_examples, mock_updated_example_obj, test_config): - ex, mock_client = mock_examples + mock_client, mock_updated_example_obj, test_config): + ex = Examples(agent_id=test_config["agent_id"]) mock_client.return_value.update_example.return_value = ( mock_updated_example_obj ) @@ -191,8 +183,8 @@ def test_update_example_with_obj( assert res.display_name == "updated_test_example" def test_update_example_with_kwargs( - mock_examples, mock_example_obj, test_config): - ex, mock_client = mock_examples + mock_client, mock_example_obj, test_config): + ex = Examples(agent_id=test_config["agent_id"]) mock_client.return_value.get_example.return_value = mock_example_obj mock_client.return_value.update_example.return_value = mock_example_obj res = ex.update_example( @@ -204,14 +196,14 @@ def test_update_example_with_kwargs( assert res.display_name == "updated_test_example" # Test delete_example -def test_delete_example(mock_examples, test_config): - ex, mock_client = mock_examples +def test_delete_example(mock_client, test_config): + ex = Examples(agent_id=test_config["agent_id"]) ex.delete_example(example_id=test_config["example_id"]) mock_client.return_value.delete_example.assert_called() # Test get_playbook_state -def test_get_playbook_state(mock_examples): - ex, _ = mock_examples +def test_get_playbook_state(test_config): + ex = Examples(agent_id=test_config["agent_id"]) assert ex.get_playbook_state("OK") == 1 assert ex.get_playbook_state("CANCELLED") == 2 assert ex.get_playbook_state("FAILED") == 3 @@ -220,8 +212,8 @@ def test_get_playbook_state(mock_examples): assert ex.get_playbook_state(None) == 0 # Test build_example_from_action_list_dict -def test_build_example_from_action_list(mock_examples): - ex, _ = mock_examples +def test_build_example_from_action_list(test_config): + ex = Examples(agent_id=test_config["agent_id"]) action_list = [ {"user_utterance": "hello"}, {"agent_utterance": "hi there"}, @@ -234,8 +226,8 @@ def test_build_example_from_action_list(mock_examples): assert len(example.actions) == 2 # Test build_playbook_invocation -def test_build_playbook_invocation(mock_examples, test_config): - ex, _ = mock_examples +def test_build_playbook_invocation(test_config): + ex = Examples(agent_id=test_config["agent_id"]) ex.playbooks_map = {"test_playbook": test_config["playbook_id"]} action = {"playbook_name": "test_playbook"} diff --git a/tests/dfcx_scrapi/core/test_playbooks.py b/tests/dfcx_scrapi/core/test_playbooks.py index 5b572c53..3e22ed6e 100644 --- a/tests/dfcx_scrapi/core/test_playbooks.py +++ b/tests/dfcx_scrapi/core/test_playbooks.py @@ -27,7 +27,10 @@ @pytest.fixture def test_config(): - agent_id = "projects/mock-test/locations/global/agents/a1s2d3f4" + project_id = "my-project-id-1234" + location_id = "global" + parent = f"projects/{project_id}/locations/{location_id}" + agent_id = f"{parent}/agents/my-agent-1234" playbook_id = f"{agent_id}/playbooks/1234" goal = """You are a Google caliber software engineer that helps users write code.""" @@ -99,6 +102,7 @@ def test_config(): playbook_version_description = "v1.0" return { + "project_id": project_id, "agent_id": agent_id, "playbook_id": playbook_id, "goal": goal, @@ -178,28 +182,36 @@ def mock_list_playbooks_pager(mock_playbook_obj_list): ) -@pytest.fixture -def mock_playbooks(monkeypatch, test_config): - """Fixture to create a Playbooks object with a mocked PlaybooksClient.""" - mock_playbooks_client = MagicMock() - monkeypatch.setattr( - "dfcx_scrapi.core.playbooks.services.playbooks.PlaybooksClient", - mock_playbooks_client - ) +@pytest.fixture(autouse=True) +def mock_client(test_config): + """Fixture to create a mocked PlaybooksClient.""" + with patch("dfcx_scrapi.core.scrapi_base.default") as mock_default, \ + patch("dfcx_scrapi.core.scrapi_base.Request") as mock_request, \ + patch("dfcx_scrapi.core.playbooks.services.playbooks.PlaybooksClient") as mock_client, \ + patch("dfcx_scrapi.core.agents.Agents.__init__") as mock_agents_init: - mock_agents_client = MagicMock() - monkeypatch.setattr( - "dfcx_scrapi.core.agents.services.agents.AgentsClient", - mock_agents_client - ) + mock_creds = MagicMock() + mock_default.return_value = (mock_creds, test_config["project_id"]) + mock_request.return_value = MagicMock() + mock_agents_init.return_value = None - playbooks = Playbooks(agent_id=test_config["agent_id"]) - yield playbooks, mock_playbooks_client, mock_agents_client + yield mock_client + +@pytest.fixture +def mock_agents_client(test_config): + """Fixture to create a mocked AgentsClient.""" + with patch("dfcx_scrapi.core.scrapi_base.default") as mock_default, \ + patch("dfcx_scrapi.core.scrapi_base.Request") as mock_request, \ + patch("dfcx_scrapi.core.agents.services.agents.AgentsClient") as mock_client: + mock_creds = MagicMock() + mock_default.return_value = (mock_creds, test_config["project_id"]) + mock_request.return_value = MagicMock() + yield mock_client # Return control to test method # Test get_playbooks_map -def test_get_playbooks_map(mock_playbooks, mock_list_playbooks_pager, test_config): - pb, mock_client, _ = mock_playbooks +def test_get_playbooks_map(mock_client, mock_list_playbooks_pager, test_config): + pb = Playbooks(agent_id=test_config["agent_id"]) mock_client.return_value.list_playbooks.return_value = mock_list_playbooks_pager # pylint: disable=C0301 res = pb.get_playbooks_map(agent_id=test_config["agent_id"]) @@ -209,8 +221,8 @@ def test_get_playbooks_map(mock_playbooks, mock_list_playbooks_pager, test_confi # Test list_playbooks -def test_list_playbooks(mock_playbooks, mock_list_playbooks_pager, test_config): - pb, mock_client, _ = mock_playbooks +def test_list_playbooks(mock_client, mock_list_playbooks_pager, test_config): + pb = Playbooks(agent_id=test_config["agent_id"]) mock_client.return_value.list_playbooks.return_value = mock_list_playbooks_pager # pylint: disable=C0301 res = pb.list_playbooks() @@ -219,8 +231,8 @@ def test_list_playbooks(mock_playbooks, mock_list_playbooks_pager, test_config): # Test get_playbook -def test_get_playbook(mock_playbooks, mock_playbook_obj_list, test_config): - pb, mock_client, _ = mock_playbooks +def test_get_playbook(mock_client, mock_playbook_obj_list, test_config): + pb = Playbooks(agent_id=test_config["agent_id"]) mock_client.return_value.get_playbook.return_value = mock_playbook_obj_list res = pb.get_playbook(playbook_id=test_config["playbook_id"]) @@ -230,8 +242,8 @@ def test_get_playbook(mock_playbooks, mock_playbook_obj_list, test_config): # Test create_playbook def test_create_playbook_from_kwargs_instruction_list( - mock_playbooks, mock_playbook_obj_list, test_config): - pb, mock_client, _ = mock_playbooks + mock_client, mock_playbook_obj_list, test_config): + pb = Playbooks(agent_id=test_config["agent_id"]) mock_client.return_value.create_playbook.return_value = mock_playbook_obj_list # pylint: disable=C0301 res = pb.create_playbook( agent_id=test_config["agent_id"], @@ -244,8 +256,8 @@ def test_create_playbook_from_kwargs_instruction_list( assert res.instruction == test_config["instructions_proto_from_list"] def test_create_playbook_from_kwargs_instruction_str( - mock_playbooks, mock_playbook_obj_str, test_config): - pb, mock_client, _ = mock_playbooks + mock_client, mock_playbook_obj_str, test_config): + pb = Playbooks(agent_id=test_config["agent_id"]) mock_client.return_value.create_playbook.return_value = mock_playbook_obj_str # pylint: disable=C0301 res = pb.create_playbook( agent_id=test_config["agent_id"], @@ -258,8 +270,8 @@ def test_create_playbook_from_kwargs_instruction_str( assert res.instruction == test_config["instructions_proto_from_str"] def test_create_playbook_from_proto_object( - mock_playbooks, mock_playbook_obj_list, test_config): - pb, mock_client, _ = mock_playbooks + mock_client, mock_playbook_obj_list, test_config): + pb = Playbooks(agent_id=test_config["agent_id"]) mock_client.return_value.create_playbook.return_value = mock_playbook_obj_list # pylint: disable=C0301 res = pb.create_playbook( agent_id=test_config["agent_id"], @@ -271,8 +283,8 @@ def test_create_playbook_from_proto_object( # Test update_playbook def test_update_playbook_with_obj( - mock_playbooks, mock_updated_playbook_obj, test_config): - pb, mock_client, _ = mock_playbooks + mock_client, mock_updated_playbook_obj, test_config): + pb = Playbooks(agent_id=test_config["agent_id"]) mock_client.return_value.update_playbook.return_value = ( mock_updated_playbook_obj ) @@ -286,8 +298,8 @@ def test_update_playbook_with_obj( def test_update_playbook_with_kwargs( - mock_playbooks, mock_playbook_obj_list, test_config): - pb, mock_client, _ = mock_playbooks + mock_client, mock_playbook_obj_list, test_config): + pb = Playbooks(agent_id=test_config["agent_id"]) mock_client.return_value.get_playbook.return_value = mock_playbook_obj_list mock_client.return_value.update_playbook.return_value = mock_playbook_obj_list # pylint: disable=C0301 res = pb.update_playbook( @@ -300,8 +312,8 @@ def test_update_playbook_with_kwargs( # Test the playbook kwarg processing helper methods def test_process_playbook_kwargs_display_name( - mock_playbooks, mock_playbook_obj_str, mock_updated_playbook_obj): - pb, _, _ = mock_playbooks + mock_playbook_obj_str, mock_updated_playbook_obj, test_config): + pb = Playbooks(agent_id=test_config["agent_id"]) kwargs = {"display_name": "mock playbook updated"} expected_mask = field_mask_pb2.FieldMask(paths=["display_name"]) @@ -311,9 +323,9 @@ def test_process_playbook_kwargs_display_name( assert expected_mask == mask def test_process_playbook_kwargs_instruction_list( - mock_playbooks, mock_playbook_obj_empty_instructions, + mock_playbook_obj_empty_instructions, mock_playbook_obj_list, test_config): - pb, _, _ = mock_playbooks + pb = Playbooks(agent_id=test_config["agent_id"]) # patch the object so we can track the internal method call with patch.object( @@ -332,9 +344,9 @@ def test_process_playbook_kwargs_instruction_list( test_config["instructions_list"]) def test_process_playbook_kwargs_instruction_str( - mock_playbooks, mock_playbook_obj_empty_instructions, + mock_playbook_obj_empty_instructions, mock_playbook_obj_str, test_config): - pb, _, _ = mock_playbooks + pb = Playbooks(agent_id=test_config["agent_id"]) # patch the object so we can track the internal method call with patch.object( @@ -354,9 +366,9 @@ def test_process_playbook_kwargs_instruction_str( ) def test_process_playbook_kwargs_instruction_obj( - mock_playbooks, mock_playbook_obj_empty_instructions, + mock_playbook_obj_empty_instructions, mock_playbook_obj_str, test_config): - pb, _, _ = mock_playbooks + pb = Playbooks(agent_id=test_config["agent_id"]) kwargs = {"instructions": test_config["instructions_proto_from_str"]} expected_mask = field_mask_pb2.FieldMask(paths=["instruction"]) @@ -367,37 +379,37 @@ def test_process_playbook_kwargs_instruction_obj( assert expected_mask == mask # Test delete_playbook -def test_delete_playbook(mock_playbooks, test_config): - pb, mock_client, _ = mock_playbooks +def test_delete_playbook(mock_client, test_config): + pb = Playbooks(agent_id=test_config["agent_id"]) pb.delete_playbook(playbook_id=test_config["playbook_id"]) mock_client.return_value.delete_playbook.assert_called() # Test set_default_playbook -def test_set_default_playbook(mock_playbooks, mock_agent_obj, test_config): - pb, _, agent_client = mock_playbooks - agent_client.return_value.get_agent.return_value = mock_agent_obj - agent_client.return_value.update_agent.return_value = mock_agent_obj +def test_set_default_playbook(mock_agents_client, mock_agent_obj, test_config): + pb = Playbooks(agent_id=test_config["agent_id"]) + mock_agents_client.return_value.get_agent.return_value = mock_agent_obj + mock_agents_client.return_value.update_agent.return_value = mock_agent_obj pb.set_default_playbook(playbook_id=test_config["playbook_id"]) assert mock_agent_obj.start_playbook == test_config["playbook_id"] # Test build instruction helpers -def test_build_instructions_from_list(mock_playbooks, test_config): - pb, _, _ = mock_playbooks +def test_build_instructions_from_list(test_config): + pb = Playbooks(agent_id=test_config["agent_id"]) res = pb.build_instructions_from_list( instructions=test_config["instructions_list"]) assert res == test_config["instructions_proto_from_list"] -def test_build_instructions_from_str(mock_playbooks, test_config): - pb, _, _ = mock_playbooks +def test_build_instructions_from_str(test_config): + pb = Playbooks(agent_id=test_config["agent_id"]) res = pb.build_instructions_from_string( instructions=test_config["instructions_str"]) assert res == test_config["instructions_proto_from_str"] -def test_parse_steps_simple_list(mock_playbooks): - pb, _, _ = mock_playbooks +def test_parse_steps_simple_list(test_config): + pb = Playbooks(agent_id=test_config["agent_id"]) lines = [ "Step 1", @@ -415,8 +427,8 @@ def test_parse_steps_simple_list(mock_playbooks): assert steps == expected_steps assert next_index == 3 -def test_parse_steps_nested_list(mock_playbooks, test_config): - pb, _, _ = mock_playbooks +def test_parse_steps_nested_list(test_config): + pb = Playbooks(agent_id=test_config["agent_id"]) lines = [ "- Step 1", @@ -434,8 +446,8 @@ def test_parse_steps_nested_list(mock_playbooks, test_config): assert next_index == 8 def test_create_playbook_version_no_description( - mock_playbooks, test_config, mock_playbook_version_obj_no_description): - pb, mock_client, _ = mock_playbooks + mock_client, test_config, mock_playbook_version_obj_no_description): + pb = Playbooks(agent_id=test_config["agent_id"]) mock_client.return_value.create_playbook_version.return_value = mock_playbook_version_obj_no_description @@ -447,8 +459,8 @@ def test_create_playbook_version_no_description( assert res.description == "" def test_create_playbook_version_with_description( - mock_playbooks, test_config, mock_playbook_version_obj_with_description): - pb, mock_client, _ = mock_playbooks + mock_client, test_config, mock_playbook_version_obj_with_description): + pb = Playbooks(agent_id=test_config["agent_id"]) mock_client.return_value.create_playbook_version.return_value = mock_playbook_version_obj_with_description diff --git a/tests/dfcx_scrapi/core/test_scrapi_base.py b/tests/dfcx_scrapi/core/test_scrapi_base.py new file mode 100644 index 00000000..53d39532 --- /dev/null +++ b/tests/dfcx_scrapi/core/test_scrapi_base.py @@ -0,0 +1,473 @@ +"""Test Class for Base Class Methods in SCRAPI.""" + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import patch, MagicMock +from google.oauth2.credentials import Credentials as UserCredentials +from google.oauth2.service_account import Credentials as ServiceAccountCredentials +from google.api_core import exceptions +from google.protobuf import field_mask_pb2, struct_pb2 +from google.cloud.dialogflowcx_v3beta1 import types + +from dfcx_scrapi.core.scrapi_base import ( + api_call_counter_decorator, + should_retry, + retry_api_call, + handle_api_error, + ScrapiBase + ) + +@pytest.fixture +def test_config(): + project_id = "my-project-id-1234" + default_id = "00000000-0000-0000-0000-000000000000" + + global_parent = f"projects/{project_id}/locations/global" + global_agent_id = f"{global_parent}/agents/my-agent-1234" + global_datastore_id = f"{global_parent}/dataStores/test-datastore" + global_flow_id = f"{global_agent_id}/flows/{default_id}" + + non_global_parent = f"projects/{project_id}/locations/us-central1" + non_global_agent_id = f"{non_global_parent}/agents/my-agent-1234" + non_global_datastore_id = f"{non_global_parent}/dataStores/test-datastore" + + email = "mock_email@testing.com" + creds_path = "/Users/path/to/creds/credentials.json" + creds_dict = { + "type": "service_account", + "project_id": project_id, + "private_key_id": "1234", + "private_key": "mock_key", + "client_email": f"mock-account@{project_id}.iam.gserviceaccount.com", + "client_id": "1234", + "universe_domain": "googleapis.com", + } + global_scopes = [ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/dialogflow", + ] + + mock_signer = MagicMock() + mock_signer.key_id = "mock_key_id" + mock_signer.sign.return_value = b"mock_signature" + + creds_object = ServiceAccountCredentials( + signer=mock_signer, + token_uri="https://oauth2.googleapis.com/token", + service_account_email=email, + project_id=project_id, + quota_project_id=project_id, + scopes=[], + ) + creds_object.token = "mock_token" + creds_object.refresh = MagicMock() + + adc_creds = UserCredentials( + token="mock_user_token", + client_id="mock_client_id", + client_secret="mock_client_secret", + quota_project_id=project_id + ) + adc_creds.refresh = MagicMock() + + return { + "project_id": project_id, + "global_scopes": global_scopes, + "global_parent": global_parent, + "global_agent_id": global_agent_id, + "global_flow_id": global_flow_id, + "global_datastore_id": global_datastore_id, + "creds_path": creds_path, + "creds_dict": creds_dict, + "creds_object": creds_object, + "adc_creds_object": adc_creds, + "non_global_agent_id": non_global_agent_id, + "non_global_datastore_id": non_global_datastore_id, + } + +@pytest.fixture +def mocked_scrapi_base(test_config): + """Fixture to provide a ScrapiBase instance with mocked default creds.""" + with patch("dfcx_scrapi.core.scrapi_base.default") as mock_default: + mock_adc_creds = test_config["adc_creds_object"] + mock_default.return_value = (mock_adc_creds, "mock_project_id") + yield ScrapiBase() + + +@patch("dfcx_scrapi.core.scrapi_base.default") +def test_init_no_creds(mock_default, test_config): + """Test initialization with no credentials provided.""" + mock_adc_creds = test_config["adc_creds_object"] + mock_default.return_value = (mock_adc_creds, "mock_project_id") + + scrapi_base = ScrapiBase() + + assert scrapi_base.creds == mock_adc_creds + assert isinstance(scrapi_base.creds, UserCredentials) + assert scrapi_base.token == mock_adc_creds.token + assert scrapi_base.agent_id is None + assert not scrapi_base.creds.requires_scopes + mock_adc_creds.refresh.assert_called_once() + mock_default.assert_called_once() + +def test_init_with_creds(test_config): + """Test initialization with user provided credentials.""" + mock_creds = test_config["creds_object"] + + scrapi_base = ScrapiBase(creds=mock_creds) + + assert scrapi_base.creds == mock_creds + assert isinstance(scrapi_base.creds, ServiceAccountCredentials) + assert scrapi_base.token == mock_creds.token + assert scrapi_base.agent_id is None + assert scrapi_base.creds.requires_scopes + mock_creds.refresh.assert_called_once() + + +@patch('google.oauth2.service_account.Credentials.from_service_account_file') +def test_init_with_creds_path(mock_from_service_account_file, test_config): + """Test initialization with credentials path.""" + mock_creds = test_config["creds_object"] + + mock_from_service_account_file.return_value = mock_creds + + + scrapi_base = ScrapiBase(creds_path=test_config["creds_path"]) + + mock_from_service_account_file.assert_called_once_with( + test_config["creds_path"], scopes=test_config["global_scopes"]) + assert scrapi_base.creds == mock_creds + assert isinstance(scrapi_base.creds, ServiceAccountCredentials) + assert scrapi_base.token == mock_creds.token + assert scrapi_base.agent_id is None + assert scrapi_base.creds.requires_scopes + mock_creds.refresh.assert_called_once() + +@patch('google.oauth2.service_account.Credentials.from_service_account_info') +def test_init_with_creds_dict(mock_from_service_account_info, test_config): + """Test initialization with credentials dictionary.""" + mock_creds = test_config["creds_object"] + mock_from_service_account_info.return_value = mock_creds + + scrapi_base = ScrapiBase(creds_dict=test_config["creds_dict"]) + + mock_from_service_account_info.assert_called_once_with( + test_config["creds_dict"], scopes=test_config["global_scopes"]) + assert scrapi_base.creds == mock_creds + assert isinstance(scrapi_base.creds, ServiceAccountCredentials) + assert scrapi_base.token == mock_creds.token + assert scrapi_base.agent_id is None + assert scrapi_base.creds.requires_scopes + mock_creds.refresh.assert_called_once() + +def test_set_region_non_global(test_config): + """Test _set_region with a non-global location.""" + client_options = ScrapiBase._set_region(test_config["non_global_agent_id"]) + assert client_options["api_endpoint"] == "us-central1-dialogflow.googleapis.com:443" + assert client_options["quota_project_id"] == test_config["project_id"] + +def test_set_region_global(test_config): + """Test _set_region with a global location.""" + client_options = ScrapiBase._set_region(test_config["global_agent_id"]) + assert client_options["api_endpoint"] == "dialogflow.googleapis.com:443" + assert client_options["quota_project_id"] == test_config["project_id"] + +def test_set_region_invalid_resource_id(): + """Test _set_region with an invalid resource ID.""" + resource_id = "invalid-resource-id" + with pytest.raises(IndexError): + ScrapiBase._set_region(resource_id) + +def test_client_options_discovery_engine_non_global(test_config): + """Test _client_options_discovery_engine with a non-global location.""" + client_options = ScrapiBase._client_options_discovery_engine( + test_config["non_global_datastore_id"]) + + assert client_options["api_endpoint"] == "us-central1-discoveryengine.googleapis.com:443" + assert client_options["quota_project_id"] == test_config["project_id"] + +def test_client_options_discovery_engine_global(test_config): + """Test _client_options_discovery_engine with a global location.""" + client_options = ScrapiBase._client_options_discovery_engine( + test_config["global_datastore_id"]) + + assert client_options["api_endpoint"] == "discoveryengine.googleapis.com:443" + assert client_options["quota_project_id"] == test_config["project_id"] + +def test_client_options_discovery_engine_invalid_resource_id(): + """Test _client_options_discovery_engine with an invalid resource ID.""" + resource_id = "invalid-resource-id" + with pytest.raises(IndexError): + ScrapiBase._client_options_discovery_engine(resource_id) + +def test_pbuf_to_dict(): + """Test pbuf_to_dict.""" + # Create a sample protobuf message + message = struct_pb2.Struct() + message["field1"] = "value1" + message["field2"] = 123 + + # Convert to dictionary + result = ScrapiBase.pbuf_to_dict(message) + + # Assert the result + assert isinstance(result, dict) + assert result["field1"] == "value1" + assert result["field2"] == 123 + +def test_dict_to_struct(): + """Test dict_to_struct.""" + input_dict = {"field1": "value1", "field2": 123} + result = ScrapiBase.dict_to_struct(input_dict) + assert isinstance(result, struct_pb2.Struct) + assert result["field1"] == "value1" + assert result["field2"] == 123 + +def test_parse_agent_id(test_config): + """Test parse_agent_id with a valid resource ID.""" + resource_id = test_config["global_flow_id"] + agent_id = ScrapiBase.parse_agent_id(resource_id) + assert agent_id == test_config["global_agent_id"] + +def test_parse_agent_id_short_resource_id(test_config): + """Test parse_agent_id with a short resource ID.""" + with pytest.raises(ValueError): + ScrapiBase.parse_agent_id(test_config["global_parent"]) + +def test_parse_agent_id_invalid_resource_id(test_config): + """Test parse_agent_id with an invalid resource ID.""" + with pytest.raises(ValueError): + ScrapiBase.parse_agent_id(test_config["global_datastore_id"]) + +@patch('dfcx_scrapi.core.scrapi_base.api_call_counter_decorator') +def test_api_call_counter_decorator(mock_decorator): + """Test api_call_counter_decorator.""" + mock_decorator.side_effect = lambda func: func # noop + + def mock_api_call(): + pass + + decorated_func = api_call_counter_decorator(mock_api_call) + assert hasattr(decorated_func, "calls_api") + assert decorated_func.calls_api is True + +def test_should_retry(): + """Test should_retry with different exception types.""" + assert should_retry(exceptions.TooManyRequests("Too many requests")) is True + assert should_retry(exceptions.ServerError("Server Error")) is True + assert should_retry(exceptions.BadRequest("Bad Request")) is False + assert should_retry(ValueError("Value error")) is False + +@patch('time.sleep') +def test_retry_api_call_success(mock_sleep): + """Test retry_api_call with a successful API call.""" + + @retry_api_call([1, 2]) + def mock_api_call(): + return "success" + + result = mock_api_call() + assert result == "success" + mock_sleep.assert_not_called() + +@patch('time.sleep') +def test_retry_api_call_too_many_requests(mock_sleep): + """Test retry_api_call with TooManyRequests exception.""" + + @retry_api_call([1, 2]) + def mock_api_call(): + raise exceptions.TooManyRequests("Too many requests") + + with pytest.raises(exceptions.TooManyRequests): + mock_api_call() + + mock_sleep.assert_called_with(2) # Second retry interval + +@patch('time.sleep') +def test_retry_api_call_server_error(mock_sleep): + """Test retry_api_call with ServerError exception.""" + + @retry_api_call([1, 2]) + def mock_api_call(): + raise exceptions.ServerError("Server error") + + with pytest.raises(exceptions.ServerError): + mock_api_call() + + mock_sleep.assert_called_with(2) # Second retry interval + +@patch('time.sleep') +def test_retry_api_call_bad_request(mock_sleep): + """Test retry_api_call with BadRequest exception.""" + + @retry_api_call([1, 2]) + def mock_api_call(): + raise exceptions.BadRequest("Bad request") + + with pytest.raises(exceptions.BadRequest): + mock_api_call() + + mock_sleep.assert_not_called() # No retries for BadRequest + +def test_handle_api_error_success(): + """Test handle_api_error with a successful API call.""" + + @handle_api_error + def mock_api_call(): + return "success" + + result = mock_api_call() + assert result == "success" + +def test_handle_api_error_google_api_call_error(): + """Test handle_api_error with GoogleAPICallError exception.""" + + @handle_api_error + def mock_api_call(): + raise exceptions.GoogleAPICallError("API error") + + result = mock_api_call() + assert result is None + + +def test_handle_api_error_value_error(): + """Test handle_api_error with ValueError exception.""" + + @handle_api_error + def mock_api_call(): + raise ValueError("Value error") + + with pytest.raises(ValueError): + mock_api_call() + +def test_update_kwargs_with_kwargs(): + """Test _update_kwargs with kwargs provided.""" + environment = types.Environment() + kwargs = {"display_name": "New Display Name", "description": "Updated Description"} + field_mask = ScrapiBase._update_kwargs(environment, **kwargs) + + assert environment.display_name == "New Display Name" + assert environment.description == "Updated Description" + assert field_mask == field_mask_pb2.FieldMask(paths=["display_name", "description"]) + +def test_update_kwargs_no_kwargs(): + """Test _update_kwargs with no kwargs provided.""" + environment = types.Environment() + field_mask = ScrapiBase._update_kwargs(environment) + + # Assert that the field mask includes all expected paths for Environment + expected_paths = [ + "name", "display_name", "description", "version_configs", + "update_time", "test_cases_config", "webhook_config", + ] + assert field_mask == field_mask_pb2.FieldMask(paths=expected_paths) + +def test_update_kwargs_experiment(): + """Test _update_kwargs with an Experiment object.""" + experiment = types.Experiment() + field_mask = ScrapiBase._update_kwargs(experiment) + + # Assert that the field mask includes all expected paths for Experiment + expected_paths = [ + "name", "display_name", "description", "state", "definition", + "rollout_config", "rollout_state", "rollout_failure_reason", + "result", "create_time", "start_time", "end_time", + "last_update_time", "experiment_length", "variants_history", + ] + assert field_mask == field_mask_pb2.FieldMask(paths=expected_paths) + +def test_update_kwargs_test_case(): + """Test _update_kwargs with a TestCase object.""" + test_case = types.TestCase() + field_mask = ScrapiBase._update_kwargs(test_case) + + # Assert that the field mask includes all expected paths for TestCase + expected_paths = [ + "name", "tags", "display_name", "notes", "test_config", + "test_case_conversation_turns", "creation_time", + "last_test_result", + ] + assert field_mask == field_mask_pb2.FieldMask(paths=expected_paths) + +def test_update_kwargs_version(): + """Test _update_kwargs with a Version object.""" + version = types.Version() + field_mask = ScrapiBase._update_kwargs(version) + + # Assert that the field mask includes all expected paths for Version + expected_paths = [ + "name", "display_name", "description", "nlu_settings", + "create_time", "state", + ] + assert field_mask == field_mask_pb2.FieldMask(paths=expected_paths) + +def test_update_kwargs_invalid_object(): + """Test _update_kwargs with an invalid object type.""" + with pytest.raises(ValueError) as err: + ScrapiBase._update_kwargs("invalid_object") + + assert str(err.value) == ( + "`obj` should be one of the following: " + "[Environment, Experiment, TestCase, Version]." + ) + +def test_get_api_calls_details_no_calls(mocked_scrapi_base): + """Test get_api_calls_details with no API calls made.""" + api_calls_details = mocked_scrapi_base.get_api_calls_details() + + # Assert that the dictionary is empty, as no API calls have been made + assert api_calls_details == {} + +def test_get_api_calls_details_with_calls(mocked_scrapi_base): + """Test get_api_calls_details with API calls made.""" + @api_call_counter_decorator + def mock_api_call(self): + pass + + # Simulate API calls + mock_api_call(mocked_scrapi_base) # Call once + mock_api_call(mocked_scrapi_base) # Call again + + api_calls_details = mocked_scrapi_base.get_api_calls_details() + + # Assert that the dictionary contains the correct counts + assert api_calls_details["mock_api_call"] == 2 + +def test_get_api_calls_count_no_calls(mocked_scrapi_base): + """Test get_api_calls_count with no API calls made.""" + api_calls_count = mocked_scrapi_base.get_api_calls_count() + + # Assert that the count is 0, as no API calls have been made + assert api_calls_count == 0 + +@patch("dfcx_scrapi.core.scrapi_base.default") +def test_get_api_calls_count_with_calls(mocked_scrapi_base): + """Test get_api_calls_count with API calls made.""" + with patch.object(mocked_scrapi_base, "get_api_calls_count") as mock_count: + mock_count.return_value = 2 + + @api_call_counter_decorator + def mock_api_call(self): + pass + + # Simulate API calls + mock_api_call(mocked_scrapi_base) # Call once + mock_api_call(mocked_scrapi_base) # Call again + + api_calls_count = mocked_scrapi_base.get_api_calls_count() + + # Assert that the count is correct + assert api_calls_count == 2 diff --git a/tests/dfcx_scrapi/core/test_test_cases.py b/tests/dfcx_scrapi/core/test_test_cases.py index 751d9cc2..1fbcb81d 100644 --- a/tests/dfcx_scrapi/core/test_test_cases.py +++ b/tests/dfcx_scrapi/core/test_test_cases.py @@ -19,7 +19,7 @@ import pytest import pandas as pd -from unittest.mock import patch +from unittest.mock import patch, MagicMock from dfcx_scrapi.core.test_cases import TestCases as PyTestCases from google.cloud.dialogflowcx_v3beta1 import types from google.cloud.dialogflowcx_v3beta1.services import test_cases @@ -27,7 +27,10 @@ @pytest.fixture def test_config(): - agent_id = "projects/mock-test/locations/global/agents/a1s2d3f4" + project_id = "my-project-id-1234" + location_id = "global" + parent = f"projects/{project_id}/locations/{location_id}" + agent_id = f"{parent}/agents/my-agent-1234" flow_id = f"{agent_id}/flows/00000000-0000-0000-0000-000000000000" page_id = f"{flow_id}/pages/mock-page-1234" other_flow_id = f"{agent_id}/flows/other1234" @@ -42,6 +45,7 @@ def test_config(): } return { + "project_id": project_id, "agent_id": agent_id, "flow_id": flow_id, "page_id": page_id, @@ -156,8 +160,19 @@ def mock_list_tc_pager_no_turns(mock_tc_obj_no_turns): ), ) +@pytest.fixture(autouse=True) +def mock_client(test_config): + """Fixture to create a mocked TestCasesClient.""" + with patch("dfcx_scrapi.core.scrapi_base.default") as mock_default, \ + patch("dfcx_scrapi.core.scrapi_base.Request") as mock_request, \ + patch("dfcx_scrapi.core.test_cases.services.test_cases.TestCasesClient") as mock_client: + + mock_creds = MagicMock() + mock_default.return_value = (mock_creds, test_config["project_id"]) + mock_request.return_value = MagicMock() + + yield mock_client -# Private Methods def test_convert_test_result_to_string(mock_tc_obj_turns): tests = [(0, "TEST_RESULT_UNSPECIFIED"), (1, "PASSED"), (2, "FAILED")] tc = PyTestCases() @@ -311,10 +326,8 @@ def test_retest_cases(mock_batch_run, mock_tc_df): # List Test Cases -@patch("dfcx_scrapi.core.test_cases.services.test_cases.TestCasesClient") def test_list_test_cases_agent_id_not_in_instance( - mock_client, mock_tc_obj_turns -): + mock_client, mock_tc_obj_turns): mock_client.return_value.list_test_cases.return_value = [mock_tc_obj_turns] tc = PyTestCases() @@ -323,7 +336,6 @@ def test_list_test_cases_agent_id_not_in_instance( _ = tc.list_test_cases() -@patch("dfcx_scrapi.core.test_cases.services.test_cases.TestCasesClient") def test_list_test_cases_agent_id_in_instance( mock_client, mock_list_tc_pager, test_config ): @@ -336,7 +348,6 @@ def test_list_test_cases_agent_id_in_instance( assert isinstance(res[0], types.TestCase) -@patch("dfcx_scrapi.core.test_cases.services.test_cases.TestCasesClient") def test_list_test_cases_agent_id_in_method( mock_client, mock_list_tc_pager_no_turns, test_config ): @@ -352,7 +363,6 @@ def test_list_test_cases_agent_id_in_method( assert res[0].test_case_conversation_turns == "" -@patch("dfcx_scrapi.core.test_cases.services.test_cases.TestCasesClient") def test_list_test_cases_include_conversation_turns( mock_client, mock_list_tc_pager, test_config ): @@ -367,11 +377,6 @@ def test_list_test_cases_include_conversation_turns( assert isinstance(res[0], types.TestCase) assert res[0].test_case_conversation_turns != "" - -# Update Test Cases - - -@patch("dfcx_scrapi.core.test_cases.services.test_cases.TestCasesClient") def test_update_test_case_no_args(mock_client, mock_tc_obj_turns): mock_client.return_value.update_test_case.return_value = mock_tc_obj_turns @@ -381,8 +386,6 @@ def test_update_test_case_no_args(mock_client, mock_tc_obj_turns): with pytest.raises(ValueError): _ = tc.update_test_case() - -@patch("dfcx_scrapi.core.test_cases.services.test_cases.TestCasesClient") def test_update_test_case_kwargs_only(mock_client, mock_tc_obj_turns): mock_client.return_value.update_test_case.return_value = mock_tc_obj_turns @@ -392,8 +395,6 @@ def test_update_test_case_kwargs_only(mock_client, mock_tc_obj_turns): with pytest.raises(ValueError): _ = tc.update_test_case(display_name="mock test case object update") - -@patch("dfcx_scrapi.core.test_cases.services.test_cases.TestCasesClient") def test_update_test_case_id_only(mock_client, test_config, mock_tc_obj_turns): mock_client.return_value.update_test_case.return_value = mock_tc_obj_turns @@ -404,8 +405,6 @@ def test_update_test_case_id_only(mock_client, test_config, mock_tc_obj_turns): with pytest.raises(ValueError): _ = tc.update_test_case(test_case_id=test_config["test_case_id"]) - -@patch("dfcx_scrapi.core.test_cases.services.test_cases.TestCasesClient") def test_update_test_case_obj_only(mock_client, mock_tc_obj_turns): mock_client.return_value.update_test_case.return_value = mock_tc_obj_turns @@ -416,8 +415,6 @@ def test_update_test_case_obj_only(mock_client, mock_tc_obj_turns): assert result.display_name == "mock test case object" assert result == mock_tc_obj_turns - -@patch("dfcx_scrapi.core.test_cases.services.test_cases.TestCasesClient") def test_update_test_case_obj_only_empty_name(mock_client, mock_tc_obj_turns): mock_tc_obj_turns.name = "" @@ -429,8 +426,6 @@ def test_update_test_case_obj_only_empty_name(mock_client, mock_tc_obj_turns): with pytest.raises(ValueError): _ = tc.update_test_case(obj=mock_tc_obj_turns) - -@patch("dfcx_scrapi.core.test_cases.services.test_cases.TestCasesClient") def test_update_test_case_with_obj_and_kwargs( mock_client, mock_tc_obj_turns, mock_updated_tc_obj ): @@ -446,8 +441,6 @@ def test_update_test_case_with_obj_and_kwargs( assert result.display_name == mock_updated_tc_obj.display_name assert result == mock_updated_tc_obj - -@patch("dfcx_scrapi.core.test_cases.services.test_cases.TestCasesClient") def test_update_test_case_with_id_and_kwargs( mock_client, test_config, mock_tc_obj_turns ): @@ -465,8 +458,6 @@ def test_update_test_case_with_id_and_kwargs( # Assertions assert result.display_name == "mock test case object update" - -@patch("dfcx_scrapi.core.test_cases.services.test_cases.TestCasesClient") def test_update_test_case_with_obj_id_and_kwargs( mock_client, test_config, mock_tc_obj_turns ): diff --git a/tests/dfcx_scrapi/core/test_tools.py b/tests/dfcx_scrapi/core/test_tools.py index 48ec4377..2a268a48 100644 --- a/tests/dfcx_scrapi/core/test_tools.py +++ b/tests/dfcx_scrapi/core/test_tools.py @@ -19,14 +19,17 @@ # limitations under the License. import pytest -from unittest.mock import patch +from unittest.mock import patch, MagicMock from dfcx_scrapi.core.tools import Tools from google.cloud.dialogflowcx_v3beta1 import types from google.cloud.dialogflowcx_v3beta1 import services @pytest.fixture def test_config(): - agent_id = "projects/mock-test/locations/global/agents/a1s2d3f4" + project_id = "my-project-id-1234" + location_id = "global" + parent = f"projects/{project_id}/locations/{location_id}" + agent_id = f"{parent}/agents/my-agent-1234" tool_id = f"{agent_id}/tools/1234" display_name = "mock tool" description = "This is a mock tool." @@ -69,6 +72,7 @@ def test_config(): """ return { + "project_id": project_id, "agent_id": agent_id, "tool_id": tool_id, "display_name": display_name, @@ -117,8 +121,20 @@ def mock_list_playbooks_pager(mock_playbook_obj): types.playbook.ListPlaybooksResponse(playbooks=[mock_playbook_obj]), ) +@pytest.fixture(autouse=True) +def mock_client(test_config): + """Fixture to create a mocked ToolsClient.""" + with patch("dfcx_scrapi.core.scrapi_base.default") as mock_default, \ + patch("dfcx_scrapi.core.scrapi_base.Request") as mock_request, \ + patch("dfcx_scrapi.core.tools.services.tools.ToolsClient") as mock_client: + + mock_creds = MagicMock() + mock_default.return_value = (mock_creds, test_config["project_id"]) + mock_request.return_value = MagicMock() + + yield mock_client + # Test get_tools_map -@patch("dfcx_scrapi.core.tools.services.tools.ToolsClient") def test_get_tools_map(mock_client, mock_list_tools_pager, test_config): mock_client.return_value.list_tools.return_value = mock_list_tools_pager tools = Tools(agent_id=test_config["agent_id"]) @@ -129,7 +145,6 @@ def test_get_tools_map(mock_client, mock_list_tools_pager, test_config): assert res[test_config["tool_id"]] == test_config["display_name"] # Test get_tools_map (reversed) -@patch("dfcx_scrapi.core.tools.services.tools.ToolsClient") def test_get_tools_map_reversed( mock_client, mock_list_tools_pager, test_config): mock_client.return_value.list_tools.return_value = mock_list_tools_pager @@ -141,7 +156,6 @@ def test_get_tools_map_reversed( assert res[test_config["display_name"]] == test_config["tool_id"] # Test list_tools -@patch("dfcx_scrapi.core.tools.services.tools.ToolsClient") def test_list_tools(mock_client, mock_list_tools_pager, test_config): mock_client.return_value.list_tools.return_value = mock_list_tools_pager tools = Tools(agent_id=test_config["agent_id"]) @@ -151,7 +165,6 @@ def test_list_tools(mock_client, mock_list_tools_pager, test_config): assert isinstance(res[0], types.Tool) # Test get_tool -@patch("dfcx_scrapi.core.tools.services.tools.ToolsClient") def test_get_tool(mock_client, mock_tool_obj, test_config): mock_client.return_value.get_tool.return_value = mock_tool_obj tools = Tools(agent_id=test_config["agent_id"]) @@ -161,7 +174,6 @@ def test_get_tool(mock_client, mock_tool_obj, test_config): assert res.display_name == test_config["display_name"] # Test create_tool -@patch("dfcx_scrapi.core.tools.services.tools.ToolsClient") def test_create_tool_from_kwargs( mock_client, mock_tool_obj, test_config): mock_client.return_value.create_tool.return_value = mock_tool_obj @@ -173,7 +185,6 @@ def test_create_tool_from_kwargs( assert isinstance(res, types.Tool) assert res.display_name == test_config["display_name"] -@patch("dfcx_scrapi.core.tools.services.tools.ToolsClient") def test_create_tool_from_proto_object( mock_client, mock_tool_obj, test_config): mock_client.return_value.create_tool.return_value = mock_tool_obj @@ -186,7 +197,6 @@ def test_create_tool_from_proto_object( assert res.display_name == test_config["display_name"] # Test delete_tool with tool_id -@patch("dfcx_scrapi.core.tools.services.tools.ToolsClient") def test_delete_tool_with_tool_id(mock_client, test_config): tools = Tools(agent_id=test_config["agent_id"]) tools.delete_tool(tool_id=test_config["tool_id"]) @@ -195,7 +205,6 @@ def test_delete_tool_with_tool_id(mock_client, test_config): ) # Test delete_tool with obj -@patch("dfcx_scrapi.core.tools.services.tools.ToolsClient") def test_delete_tool_with_obj(mock_client, mock_tool_obj, test_config): tools = Tools(agent_id=test_config["agent_id"]) tools.delete_tool(obj=mock_tool_obj) @@ -204,7 +213,6 @@ def test_delete_tool_with_obj(mock_client, mock_tool_obj, test_config): ) # Test update_tool with kwargs -@patch("dfcx_scrapi.core.tools.services.tools.ToolsClient") def test_update_tool_with_kwargs( mock_client, mock_tool_obj_updated, test_config): mock_client.return_value.update_tool.return_value = mock_tool_obj_updated