Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added aiplatform.Model.update method #952

Merged
merged 12 commits into from
Jan 24, 2022
72 changes: 70 additions & 2 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@
env_var as gca_env_var_compat,
)

from google.protobuf import json_format

from google.protobuf import field_mask_pb2, json_format

_LOGGER = base.Logger(__name__)

Expand Down Expand Up @@ -1502,6 +1501,75 @@ def __init__(
)
self._gca_resource = self._get_gca_resource(resource_name=model_name)

def update(
self,
display_name: Optional[str] = None,
description: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
) -> "Model":
"""Updates a model.

Example usage:

my_model = my_model.update(
display_name='my-model',
description='my description',
labels={'key': 'value'},
)

Args:
display_name (str):
The display name of the Model. The name can be up to 128
characters long and can be consist of any UTF-8 characters.
description (str):
The description of the model.
labels (Dict[str, str]):
Optional. The labels with user-defined metadata to
organize your Models.
Label keys and values can be no longer than 64
characters (Unicode codepoints), can only
contain lowercase letters, numeric characters,
underscores and dashes. International characters
are allowed.
See https://goo.gl/xmQnxf for more information
and examples of labels.
Returns:
model: Updated model resource.
Raises:
ValueError: If `labels` is not the correct format.
"""

current_model_proto = self.gca_resource
ivanmkc marked this conversation as resolved.
Show resolved Hide resolved
copied_model_proto = current_model_proto.__class__(current_model_proto)

update_mask: List[str] = []

if display_name:
utils.validate_display_name(display_name)

copied_model_proto.display_name = display_name
update_mask.append("display_name")

if description:
copied_model_proto.description = description
update_mask.append("description")

if labels:
utils.validate_labels(labels)

copied_model_proto.labels = labels
update_mask.append("labels")

update_mask = field_mask_pb2.FieldMask(paths=update_mask)

_ = self.api_client.update_model(
ivanmkc marked this conversation as resolved.
Show resolved Hide resolved
model=copied_model_proto, update_mask=update_mask
)

self._sync_gca_resource()

return self
ivanmkc marked this conversation as resolved.
Show resolved Hide resolved

# TODO(b/170979552) Add support for predict schemata
# TODO(b/170979926) Add support for metadata and metadata schema
@classmethod
Expand Down
11 changes: 10 additions & 1 deletion tests/system/aiplatform/test_model_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class TestModel(e2e_base.TestEndToEnd):
_temp_prefix = f"{_TEST_PROJECT}-vertex-staging-{_TEST_LOCATION}"

def test_upload_and_deploy_xgboost_model(self, shared_state):
"""Upload XGBoost model from local file and deploy it for prediction."""
"""Upload XGBoost model from local file and deploy it for prediction. Additionally, update model name, description and labels"""

aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

Expand Down Expand Up @@ -65,3 +65,12 @@ def test_upload_and_deploy_xgboost_model(self, shared_state):
shared_state["resources"].append(endpoint)
predict_response = endpoint.predict(instances=[[0, 0, 0]])
assert len(predict_response.predictions) == 1

model = model.update(
display_name="new_name",
description="new_description",
labels={"my_label": "updated"},
)
assert model.display_name == "new_name"
assert model.display_name == "new_description"
assert model.labels == {"my_label": "updated"}
49 changes: 49 additions & 0 deletions tests/unit/aiplatform/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
encryption_spec as gca_encryption_spec,
)

from google.protobuf import field_mask_pb2

from test_endpoints import create_endpoint_mock # noqa: F401

Expand Down Expand Up @@ -177,6 +178,27 @@
_TEST_CONTAINER_REGISTRY_DESTINATION


@pytest.fixture
def mock_model():
model = mock.MagicMock(models.Model)
model.name = _TEST_ID
model._latest_future = None
model._exception = None
model._gca_resource = gca_model.Model(
display_name=_TEST_MODEL_NAME,
description=_TEST_DESCRIPTION,
labels=_TEST_LABEL,
)
yield model


@pytest.fixture
def update_model_mock(mock_model):
with patch.object(model_service_client.ModelServiceClient, "update_model") as mock:
mock.return_value = mock_model
yield mock


@pytest.fixture
def get_endpoint_mock():
with mock.patch.object(
Expand All @@ -199,6 +221,7 @@ def get_model_mock():
get_model_mock.return_value = gca_model.Model(
display_name=_TEST_MODEL_NAME, name=_TEST_MODEL_RESOURCE_NAME,
)

yield get_model_mock


Expand Down Expand Up @@ -1660,3 +1683,29 @@ def test_upload_tensorflow_saved_model_uploads_and_gets_model(
]
staged_model_file_name = staged_model_file_path.split("/")[-1]
assert staged_model_file_name in ["saved_model.pb", "saved_model.pbtxt"]

@pytest.mark.usefixtures("get_model_mock")
def test_update(self, update_model_mock, get_model_mock):

test_model = models.Model(_TEST_ID)

test_model.update(
display_name=_TEST_MODEL_NAME,
description=_TEST_DESCRIPTION,
labels=_TEST_LABEL,
)

current_model_proto = gca_model.Model(
display_name=_TEST_MODEL_NAME,
description=_TEST_DESCRIPTION,
labels=_TEST_LABEL,
name=_TEST_MODEL_RESOURCE_NAME,
)

update_mask = field_mask_pb2.FieldMask(
paths=["display_name", "description", "labels"]
)

update_model_mock.assert_called_once_with(
model=current_model_proto, update_mask=update_mask
)