Skip to content

Commit

Permalink
Add update_keywords and refactor update_test_casse
Browse files Browse the repository at this point in the history
  • Loading branch information
MRyderOC committed Feb 2, 2024
1 parent c812b41 commit fa372ea
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 25 deletions.
66 changes: 46 additions & 20 deletions src/dfcx_scrapi/core/scrapi_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from collections import defaultdict
from typing import Dict

from google.cloud.dialogflowcx_v3beta1 import types
from google.oauth2 import service_account
from google.auth.transport.requests import Request
from google.protobuf import json_format # type: ignore
Expand Down Expand Up @@ -334,26 +335,6 @@ def _get_solution_type(solution_type: str) -> int:

return solution_map[solution_type]

@staticmethod
def _update_kwargs(
resource_type: str, obj, kwargs: Dict) -> field_mask_pb2.FieldMask: # pylint: disable=W0613
"""Create a `mask` for update methods based on kwargs.
Args:
resource_type (str): The type of the input resource as a string.
obj: The protobuf object.
kwargs (Dict): A dictionary with obj's attributes as keys and
thier values as values.
Returns:
A FieldMask protobuf.
"""
for key, value in kwargs.items():
setattr(obj, key, value)

return field_mask_pb2.FieldMask(paths=kwargs.keys())


def _build_data_store_parent(self, location: str) -> str:
"""Build the Parent ID needed for Discovery Engine API calls."""
return (f"projects/{self.project_id}/locations/{location}/collections/"
Expand Down Expand Up @@ -417,6 +398,51 @@ def get_api_calls_count(self) -> int:
return sum(self.get_api_calls_details().values())


@staticmethod
def _update_kwargs(obj, **kwargs) -> field_mask_pb2.FieldMask:
"""Create a FieldMask for Environment, Experiment, TestCase, Version."""
if kwargs:
for key, value in kwargs.items():
setattr(obj, key, value)
return field_mask_pb2.FieldMask(paths=kwargs.keys())
attrs_map = {
"Environment": [
"name", "display_name", "description", "version_configs",
"update_time", "test_cases_config", "webhook_config",
],
"Experiment": [
"name", "display_name", "description", "state", "definition",
"rollout_config", "rollout_state", "rollout_failure_reason",
"result", "create_time", "start_time", "end_time",
"last_update_time", "experiment_length", "variants_history",
],
"TestCase": [
"name", "tags", "display_name", "notes", "test_config",
"test_case_conversation_turns", "creation_time",
"last_test_result",
],
"Version": [
"name", "display_name", "description", "nlu_settings",
"create_time", "state",
],
}
if isinstance(obj, types.Environment):
paths = attrs_map["Environment"]
elif isinstance(obj, types.Experiment):
paths = attrs_map["Experiment"]
elif isinstance(obj, types.TestCase):
paths = attrs_map["TestCase"]
elif isinstance(obj, types.Version):
paths = attrs_map["Version"]
else:
raise ValueError(
"`obj` should be one of the following:"
" [Environment, Experiment, TestCase, Version]."
)

return field_mask_pb2.FieldMask(paths=paths)


def api_call_counter_decorator(func):
"""Counts the number of API calls for the function `func`."""

Expand Down
10 changes: 5 additions & 5 deletions src/dfcx_scrapi/core/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,15 +462,15 @@ def update_test_case(
if obj:
test_case = obj
test_case.name = test_case_id
else:
mask = self._update_kwargs(obj)
elif kwargs:
if not test_case_id:
test_case_id = self.test_case_id
test_case = self.get_test_case(test_case_id)
mask = self._update_kwargs(obj, **kwargs)

request = types.test_case.UpdateTestCaseRequest()
request.test_case = test_case
if kwargs:
request.update_mask = self._update_kwargs("test_case", obj, kwargs)
request = types.test_case.UpdateTestCaseRequest(
test_case=test_case, update_mask=mask)

client_options = self._set_region(test_case_id)
client = services.test_cases.TestCasesClient(
Expand Down

0 comments on commit fa372ea

Please sign in to comment.