From 7404f679246e41e0009ec2d49f05d669eb357f71 Mon Sep 17 00:00:00 2001 From: Matthew Tang Date: Fri, 9 Aug 2024 14:20:38 -0700 Subject: [PATCH] feat: Support api keys in initializer and create_client PiperOrigin-RevId: 661401962 --- google/cloud/aiplatform/initializer.py | 28 +++++++++++++++++-- .../generative_models/_generative_models.py | 4 ++- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index b437f1f4d1..08c3528c57 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -74,7 +74,7 @@ def _set_project_as_env_var_or_google_auth_default(self): the project and credentials have already been set. """ - if not self._project: + if not self._project and not self._api_key: # Project is not set. Trying to get it from the environment. # See https://github.com/googleapis/python-aiplatform/issues/852 # See https://github.com/googleapis/google-auth-library-python/issues/924 @@ -104,7 +104,7 @@ def _set_project_as_env_var_or_google_auth_default(self): self._credentials = self._credentials or credentials self._project = project - if not self._credentials: + if not self._credentials and not self._api_key: credentials, _ = google.auth.default() self._credentials = credentials @@ -117,6 +117,7 @@ def __init__(self): self._network = None self._service_account = None self._api_endpoint = None + self._api_key = None self._api_transport = None self._request_metadata = None self._resource_type = None @@ -137,6 +138,7 @@ def init( network: Optional[str] = None, service_account: Optional[str] = None, api_endpoint: Optional[str] = None, + api_key: Optional[str] = None, api_transport: Optional[str] = None, request_metadata: Optional[Sequence[Tuple[str, str]]] = None, ): @@ -197,6 +199,9 @@ def init( api_endpoint (str): Optional. The desired API endpoint, e.g., us-central1-aiplatform.googleapis.com + api_key (str): + Optional. The API key to use for service calls. + NOTE: Not all services support API keys. api_transport (str): Optional. The transport method which is either 'grpc' or 'rest'. NOTE: "rest" transport functionality is currently in a @@ -252,6 +257,8 @@ def init( self._service_account = service_account if request_metadata is not None: self._request_metadata = request_metadata + if api_key is not None: + self._api_key = api_key self._resource_type = None # Finally, perform secondary state updates @@ -304,6 +311,11 @@ def api_endpoint(self) -> Optional[str]: """Default API endpoint, if provided.""" return self._api_endpoint + @property + def api_key(self) -> Optional[str]: + """API Key, if provided.""" + return self._api_key + @property def project(self) -> str: """Default project.""" @@ -325,7 +337,7 @@ def project(self) -> str: except GoogleAuthError as exc: raise GoogleAuthError(project_not_found_exception_str) from exc - if not project_id: + if not project_id and not self.api_key: raise ValueError(project_not_found_exception_str) return project_id @@ -403,6 +415,7 @@ def get_client_options( location_override: Optional[str] = None, prediction_client: bool = False, api_base_path_override: Optional[str] = None, + api_key: Optional[str] = None, api_path_override: Optional[str] = None, ) -> client_options.ClientOptions: """Creates GAPIC client_options using location and type. @@ -414,6 +427,7 @@ def get_client_options( Vertex AI. prediction_client (str): Optional. flag to use a prediction endpoint. api_base_path_override (str): Optional. Override default API base path. + api_key (str): Optional. API key to use for the client. api_path_override (str): Optional. Override default api path. Returns: clients_options (google.api_core.client_options.ClientOptions): @@ -447,6 +461,11 @@ def get_client_options( else api_path_override ) + # Project/location take precedence over api_key + if api_key and not self._project: + return client_options.ClientOptions( + api_endpoint=api_endpoint, api_key=api_key + ) return client_options.ClientOptions(api_endpoint=api_endpoint) def common_location_path( @@ -479,6 +498,7 @@ def create_client( location_override: Optional[str] = None, prediction_client: bool = False, api_base_path_override: Optional[str] = None, + api_key: Optional[str] = None, api_path_override: Optional[str] = None, appended_user_agent: Optional[List[str]] = None, appended_gapic_version: Optional[str] = None, @@ -493,6 +513,7 @@ def create_client( Optional. Custom auth credentials. If not provided will use the current config. location_override (str): Optional. location override. prediction_client (str): Optional. flag to use a prediction endpoint. + api_key (str): Optional. API key to use for the client. api_base_path_override (str): Optional. Override default api base path. api_path_override (str): Optional. Override default api path. appended_user_agent (List[str]): @@ -539,6 +560,7 @@ def create_client( "client_options": self.get_client_options( location_override=location_override, prediction_client=prediction_client, + api_key=api_key, api_base_path_override=api_base_path_override, api_path_override=api_path_override, ), diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py index b74226e6f5..f5e1cb37ff 100644 --- a/vertexai/generative_models/_generative_models.py +++ b/vertexai/generative_models/_generative_models.py @@ -126,6 +126,8 @@ def _get_resource_name_from_model_name( ) -> str: """Returns the full resource name starting with projects/ given a model name.""" if model_name.startswith("publishers/"): + if not project: + return model_name return f"projects/{project}/locations/{location}/{model_name}" elif model_name.startswith("projects/"): return model_name @@ -337,7 +339,7 @@ def __init__( location = aiplatform_utils.extract_project_and_location_from_parent( prediction_resource_name - )["location"] + ).get("location") self._model_name = model_name self._prediction_resource_name = prediction_resource_name