Skip to content

Commit

Permalink
Merge pull request #171 from GoogleCloudPlatform/bug/update_type
Browse files Browse the repository at this point in the history
Bug/update type
  • Loading branch information
kmaphoenix authored Feb 8, 2024
2 parents 456f343 + 73b99d2 commit 7679e03
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 27 deletions.
25 changes: 8 additions & 17 deletions src/dfcx_scrapi/core/environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from google.oauth2 import service_account
from google.cloud.dialogflowcx_v3beta1 import services
from google.cloud.dialogflowcx_v3beta1 import types
from google.protobuf import field_mask_pb2

from dfcx_scrapi.core import scrapi_base
from dfcx_scrapi.core import flows
from dfcx_scrapi.core import versions
Expand Down Expand Up @@ -306,31 +306,22 @@ def update_environment(
Returns:
An object representing a long-running operation. (LRO)
"""

if environment_obj:
env = environment_obj
else:
env = types.Environment()

env.name = environment_id
env.name = environment_id
mask = self._update_kwargs(environment_obj)
elif kwargs:
env = self.get_environment(environment_id)
mask = self._update_kwargs(environment_obj, **kwargs)

# set environment attributes from kwargs
for key, value in kwargs.items():
setattr(env, key, value)
paths = kwargs.keys()
mask = field_mask_pb2.FieldMask(paths=paths)
request = types.environment.UpdateEnvironmentRequest(
environment=env, update_mask=mask)

client_options = self._set_region(environment_id)
client = services.environments.EnvironmentsClient(
credentials=self.creds, client_options=client_options
)

request = types.environment.UpdateEnvironmentRequest()
request.environment = env
request.update_mask = mask

response = client.update_environment(request)

return response


Expand Down
47 changes: 47 additions & 0 deletions src/dfcx_scrapi/core/scrapi_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
from collections import defaultdict
from typing import Dict

from google.cloud.dialogflowcx_v3beta1 import types
from google.oauth2 import service_account
from google.auth.transport.requests import Request
from google.protobuf import json_format # type: ignore
from google.protobuf import field_mask_pb2

from proto.marshal.collections import repeated
from proto.marshal.collections import maps
Expand Down Expand Up @@ -396,6 +398,51 @@ def get_api_calls_count(self) -> int:
return sum(self.get_api_calls_details().values())


@staticmethod
def _update_kwargs(obj, **kwargs) -> field_mask_pb2.FieldMask:
"""Create a FieldMask for Environment, Experiment, TestCase, Version."""
if kwargs:
for key, value in kwargs.items():
setattr(obj, key, value)
return field_mask_pb2.FieldMask(paths=kwargs.keys())
attrs_map = {
"Environment": [
"name", "display_name", "description", "version_configs",
"update_time", "test_cases_config", "webhook_config",
],
"Experiment": [
"name", "display_name", "description", "state", "definition",
"rollout_config", "rollout_state", "rollout_failure_reason",
"result", "create_time", "start_time", "end_time",
"last_update_time", "experiment_length", "variants_history",
],
"TestCase": [
"name", "tags", "display_name", "notes", "test_config",
"test_case_conversation_turns", "creation_time",
"last_test_result",
],
"Version": [
"name", "display_name", "description", "nlu_settings",
"create_time", "state",
],
}
if isinstance(obj, types.Environment):
paths = attrs_map["Environment"]
elif isinstance(obj, types.Experiment):
paths = attrs_map["Experiment"]
elif isinstance(obj, types.TestCase):
paths = attrs_map["TestCase"]
elif isinstance(obj, types.Version):
paths = attrs_map["Version"]
else:
raise ValueError(
"`obj` should be one of the following:"
" [Environment, Experiment, TestCase, Version]."
)

return field_mask_pb2.FieldMask(paths=paths)


def api_call_counter_decorator(func):
"""Counts the number of API calls for the function `func`."""

Expand Down
15 changes: 5 additions & 10 deletions src/dfcx_scrapi/core/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from google.cloud.dialogflowcx_v3beta1 import services
from google.cloud.dialogflowcx_v3beta1 import types
from google.protobuf import field_mask_pb2

from dfcx_scrapi.core import scrapi_base
from dfcx_scrapi.core import flows
Expand Down Expand Up @@ -463,19 +462,15 @@ def update_test_case(
if obj:
test_case = obj
test_case.name = test_case_id
else:
mask = self._update_kwargs(obj)
elif kwargs:
if not test_case_id:
test_case_id = self.test_case_id
test_case = self.get_test_case(test_case_id)
mask = self._update_kwargs(obj, **kwargs)

for key, value in kwargs.items():
setattr(test_case, key, value)
paths = kwargs.keys()
mask = field_mask_pb2.FieldMask(paths=paths)

request = types.test_case.UpdateTestCaseRequest()
request.test_case = test_case
request.update_mask = mask
request = types.test_case.UpdateTestCaseRequest(
test_case=test_case, update_mask=mask)

client_options = self._set_region(test_case_id)
client = services.test_cases.TestCasesClient(
Expand Down

0 comments on commit 7679e03

Please sign in to comment.