diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index 2aa98b1600..bb4b165d31 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -32,6 +32,7 @@ from google.cloud.aiplatform.constants import base as constants from google.cloud.aiplatform import utils from google.cloud.aiplatform.metadata import metadata +from google.cloud.aiplatform.utils import resource_manager_utils from google.cloud.aiplatform.compat.types import ( encryption_spec as gca_encryption_spec_compat, @@ -149,6 +150,25 @@ def project(self) -> str: if self._project: return self._project + # Project is not set. Trying to get it from the environment. + # See https://github.com/googleapis/python-aiplatform/issues/852 + # See https://github.com/googleapis/google-auth-library-python/issues/924 + # TODO: Remove when google.auth.default() learns the + # CLOUD_ML_PROJECT_ID env variable or Vertex AI starts setting GOOGLE_CLOUD_PROJECT env variable. + project_number = os.environ.get("CLOUD_ML_PROJECT_ID") + if project_number: + # Try to convert project number to project ID which is more readable. + try: + project_id = resource_manager_utils.get_project_id( + project_number=project_number, credentials=self.credentials, + ) + return project_id + except Exception: + logging.getLogger(__name__).warning( + "Failed to convert project number to project ID.", exc_info=True + ) + return project_number + project_not_found_exception_str = ( "Unable to find your project. Please provide a project ID by:" "\n- Passing a constructor argument" diff --git a/tests/system/aiplatform/test_project_id_inference.py b/tests/system/aiplatform/test_project_id_inference.py new file mode 100644 index 0000000000..dc047eb350 --- /dev/null +++ b/tests/system/aiplatform/test_project_id_inference.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest + +from google.cloud import aiplatform +from google.cloud.aiplatform.compat.types import pipeline_state as gca_pipeline_state +from tests.system.aiplatform import e2e_base + + +@pytest.mark.usefixtures("prepare_staging_bucket", "delete_staging_bucket") +class TestProjectIDInference(e2e_base.TestEndToEnd): + + _temp_prefix = "temp-vertex-sdk-project-id-inference" + + def test_project_id_inference(self, shared_state): + # Collection of resources generated by this test, to be deleted during teardown + shared_state["resources"] = [] + + aiplatform.init( + location=e2e_base._LOCATION, + staging_bucket=shared_state["staging_bucket_name"], + ) + + worker_pool_specs = [ + { + "machine_spec": {"machine_type": "n1-standard-2"}, + "replica_count": 1, + "container_spec": { + "image_uri": "python:3.9", + "command": [ + "sh", + "-exc", + """python3 -m pip install git+https://github.com/Ark-kun/python-aiplatform@fix--Fixed-getitng-project-ID-when-running-on-Vertex-AI#egg=google-cloud-aiplatform&subdirectory=. + "$0" "$@" + """, + "python3", + "-c", + """ + from google.cloud import aiplatform + # Not initializing the Vertex SDK explicitly + # Checking teh project ID + print(aiplatform.initializer.global_config.project) + assert not aiplatform.initializer.global_config.project.endswith("-tp") + # Testing ability to list resources + endpoints = aiplatform.Endpoint.list() + print(endpoints) + """, + ], + "args": [], + }, + } + ] + + custom_job = aiplatform.CustomJob( + display_name=self._make_display_name("custom"), + worker_pool_specs=worker_pool_specs, + ) + custom_job.run( + enable_web_access=True, sync=False, + ) + + shared_state["resources"].append(custom_job) + + in_progress_done_check = custom_job.done() + custom_job.wait_for_resource_creation() + + completion_done_check = custom_job.done() + + assert ( + custom_job.state + == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + + # Check done() method works correctly + assert in_progress_done_check is False + assert completion_done_check is True diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py index e52dfef3aa..a85e8257bc 100644 --- a/tests/unit/aiplatform/test_initializer.py +++ b/tests/unit/aiplatform/test_initializer.py @@ -28,6 +28,7 @@ from google.cloud.aiplatform.metadata.metadata import metadata_service from google.cloud.aiplatform.constants import base as constants from google.cloud.aiplatform import utils +from google.cloud.aiplatform.utils import resource_manager_utils from google.cloud.aiplatform_v1.services.model_service import ( client as model_service_client, @@ -61,6 +62,22 @@ def mock_auth_default(): monkeypatch.setattr(google.auth, "default", mock_auth_default) assert initializer.global_config.project == _TEST_PROJECT + def test_infer_project_id(self): + cloud_project_number = "123" + + def mock_get_project_id(project_number: str, **_): + assert project_number == cloud_project_number + return _TEST_PROJECT + + with mock.patch.object( + target=resource_manager_utils, + attribute="get_project_id", + new=mock_get_project_id, + ), mock.patch.dict( + os.environ, {"CLOUD_ML_PROJECT_ID": cloud_project_number}, clear=True + ): + assert initializer.global_config.project == _TEST_PROJECT + def test_init_location_sets_location(self): initializer.global_config.init(location=_TEST_LOCATION) assert initializer.global_config.location == _TEST_LOCATION