diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index 27cb8234b6..1ab2986f07 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -48,8 +48,7 @@ env_var as gca_env_var_compat, ) -from google.protobuf import json_format - +from google.protobuf import field_mask_pb2, json_format _LOGGER = base.Logger(__name__) @@ -1502,6 +1501,73 @@ def __init__( ) self._gca_resource = self._get_gca_resource(resource_name=model_name) + def update( + self, + display_name: Optional[str] = None, + description: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + ) -> "Model": + """Updates a model. + + Example usage: + + my_model = my_model.update( + display_name='my-model', + description='my description', + labels={'key': 'value'}, + ) + + Args: + display_name (str): + The display name of the Model. The name can be up to 128 + characters long and can be consist of any UTF-8 characters. + description (str): + The description of the model. + labels (Dict[str, str]): + Optional. The labels with user-defined metadata to + organize your Models. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + Returns: + model: Updated model resource. + Raises: + ValueError: If `labels` is not the correct format. + """ + + current_model_proto = self.gca_resource + copied_model_proto = current_model_proto.__class__(current_model_proto) + + update_mask: List[str] = [] + + if display_name: + utils.validate_display_name(display_name) + + copied_model_proto.display_name = display_name + update_mask.append("display_name") + + if description: + copied_model_proto.description = description + update_mask.append("description") + + if labels: + utils.validate_labels(labels) + + copied_model_proto.labels = labels + update_mask.append("labels") + + update_mask = field_mask_pb2.FieldMask(paths=update_mask) + + self.api_client.update_model(model=copied_model_proto, update_mask=update_mask) + + self._sync_gca_resource() + + return self + # TODO(b/170979552) Add support for predict schemata # TODO(b/170979926) Add support for metadata and metadata schema @classmethod diff --git a/tests/system/aiplatform/test_model_upload.py b/tests/system/aiplatform/test_model_upload.py index 35ae44da69..625a3dc2c4 100644 --- a/tests/system/aiplatform/test_model_upload.py +++ b/tests/system/aiplatform/test_model_upload.py @@ -37,7 +37,7 @@ class TestModel(e2e_base.TestEndToEnd): _temp_prefix = f"{_TEST_PROJECT}-vertex-staging-{_TEST_LOCATION}" def test_upload_and_deploy_xgboost_model(self, shared_state): - """Upload XGBoost model from local file and deploy it for prediction.""" + """Upload XGBoost model from local file and deploy it for prediction. Additionally, update model name, description and labels""" aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) @@ -65,3 +65,12 @@ def test_upload_and_deploy_xgboost_model(self, shared_state): shared_state["resources"].append(endpoint) predict_response = endpoint.predict(instances=[[0, 0, 0]]) assert len(predict_response.predictions) == 1 + + model = model.update( + display_name="new_name", + description="new_description", + labels={"my_label": "updated"}, + ) + assert model.display_name == "new_name" + assert model.display_name == "new_description" + assert model.labels == {"my_label": "updated"} diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index bf87f3593d..f89372481d 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -54,6 +54,7 @@ encryption_spec as gca_encryption_spec, ) +from google.protobuf import field_mask_pb2 from test_endpoints import create_endpoint_mock # noqa: F401 @@ -177,6 +178,27 @@ _TEST_CONTAINER_REGISTRY_DESTINATION +@pytest.fixture +def mock_model(): + model = mock.MagicMock(models.Model) + model.name = _TEST_ID + model._latest_future = None + model._exception = None + model._gca_resource = gca_model.Model( + display_name=_TEST_MODEL_NAME, + description=_TEST_DESCRIPTION, + labels=_TEST_LABEL, + ) + yield model + + +@pytest.fixture +def update_model_mock(mock_model): + with patch.object(model_service_client.ModelServiceClient, "update_model") as mock: + mock.return_value = mock_model + yield mock + + @pytest.fixture def get_endpoint_mock(): with mock.patch.object( @@ -199,6 +221,7 @@ def get_model_mock(): get_model_mock.return_value = gca_model.Model( display_name=_TEST_MODEL_NAME, name=_TEST_MODEL_RESOURCE_NAME, ) + yield get_model_mock @@ -1660,3 +1683,29 @@ def test_upload_tensorflow_saved_model_uploads_and_gets_model( ] staged_model_file_name = staged_model_file_path.split("/")[-1] assert staged_model_file_name in ["saved_model.pb", "saved_model.pbtxt"] + + @pytest.mark.usefixtures("get_model_mock") + def test_update(self, update_model_mock, get_model_mock): + + test_model = models.Model(_TEST_ID) + + test_model.update( + display_name=_TEST_MODEL_NAME, + description=_TEST_DESCRIPTION, + labels=_TEST_LABEL, + ) + + current_model_proto = gca_model.Model( + display_name=_TEST_MODEL_NAME, + description=_TEST_DESCRIPTION, + labels=_TEST_LABEL, + name=_TEST_MODEL_RESOURCE_NAME, + ) + + update_mask = field_mask_pb2.FieldMask( + paths=["display_name", "description", "labels"] + ) + + update_model_mock.assert_called_once_with( + model=current_model_proto, update_mask=update_mask + )