diff --git a/google/cloud/aiplatform/persistent_resource.py b/google/cloud/aiplatform/persistent_resource.py index 0bd6dbe404..f0944a5bb4 100644 --- a/google/cloud/aiplatform/persistent_resource.py +++ b/google/cloud/aiplatform/persistent_resource.py @@ -420,3 +420,34 @@ def list( location=location, credentials=credentials, ) + + @base.optional_sync() + def reboot( + self, + sync: Optional[bool] = True, # pylint: disable=unused-argument + ) -> None: + """Reboots this Persistent Resource. + + Args: + name (str): + Required. The name of the PersistentResource resource. + Name should be in the following format: + ``projects/{project_id_or_number}/locations/{location_id}/persistentResources/{persistent_resource_id}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + sync (bool): + Whether to execute this method synchonously. If False, this + method will be executed in concurrent Future and any downstream + object will be immediately returned and synced when the Future + has completed. + """ + + _LOGGER.log_action_start_against_resource("Rebooting", "", self) + lro = self.api_client.reboot_persistent_resource(name=self.resource_name) + _LOGGER.log_action_started_against_resource_with_lro( + "Reboot", "", self.__class__, lro + ) + lro.result(timeout=None) + _LOGGER.log_action_completed_against_resource("rebooted.", "", self) diff --git a/tests/unit/aiplatform/test_persistent_resource.py b/tests/unit/aiplatform/test_persistent_resource.py index 14421e9066..55c460b113 100644 --- a/tests/unit/aiplatform/test_persistent_resource.py +++ b/tests/unit/aiplatform/test_persistent_resource.py @@ -153,6 +153,20 @@ def delete_persistent_resource_mock(): yield delete_persistent_resource_mock +@pytest.fixture +def reboot_persistent_resource_mock(): + with mock.patch.object( + (persistent_resource_service_client_v1.PersistentResourceServiceClient), + "reboot_persistent_resource", + ) as reboot_persistent_resource_mock: + reboot_lro = mock.Mock(ga_operation.Operation) + reboot_lro.result.return_value = ( + persistent_resource_service_v1.RebootPersistentResourceRequest() + ) + reboot_persistent_resource_mock.return_value = reboot_lro + yield reboot_persistent_resource_mock + + @pytest.mark.usefixtures("google_auth_mock") class TestPersistentResource: def setup_method(self): @@ -359,3 +373,23 @@ def test_delete_persistent_resource( delete_persistent_resource_mock.assert_called_once_with( name=_TEST_PERSISTENT_RESOURCE_ID, ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_reboot_persistent_resource( + self, + get_persistent_resource_mock, + reboot_persistent_resource_mock, + sync, + ): + test_resource = persistent_resource.PersistentResource( + _TEST_PERSISTENT_RESOURCE_ID + ) + test_resource.reboot(sync=sync) + + if not sync: + test_resource.wait() + + get_persistent_resource_mock.assert_called_once() + reboot_persistent_resource_mock.assert_called_once_with( + name=_TEST_PERSISTENT_RESOURCE_ID, + )