From 0b3f8e5ce63f75ba99ee8cf29226a0fd38bef99f Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Thu, 21 Mar 2024 01:47:20 +0000 Subject: [PATCH] feat: option to use bq connection without check (#460) * feat: option to use bq connection without check * revert breaking signature change, centralize connection manager skipping * fix bad referencing * use public property from session * revert unintended test_iam_permissions change * fix couple of more unwanted changes --- bigframes/_config/bigquery_options.py | 28 ++++++- bigframes/clients.py | 40 +++++----- bigframes/functions/remote_function.py | 27 ++++--- bigframes/ml/llm.py | 84 ++++++++++----------- bigframes/ml/remote.py | 28 +++---- bigframes/session/__init__.py | 14 +++- tests/system/small/test_remote_function.py | 34 +++++++++ tests/unit/_config/test_bigquery_options.py | 1 + tests/unit/test_clients.py | 8 +- 9 files changed, 164 insertions(+), 100 deletions(-) diff --git a/bigframes/_config/bigquery_options.py b/bigframes/_config/bigquery_options.py index 34701740f6..d035fe5df1 100644 --- a/bigframes/_config/bigquery_options.py +++ b/bigframes/_config/bigquery_options.py @@ -40,6 +40,7 @@ def __init__( use_regional_endpoints: bool = False, application_name: Optional[str] = None, kms_key_name: Optional[str] = None, + skip_bq_connection_check: bool = False, ): self._credentials = credentials self._project = project @@ -48,6 +49,7 @@ def __init__( self._use_regional_endpoints = use_regional_endpoints self._application_name = application_name self._kms_key_name = kms_key_name + self._skip_bq_connection_check = skip_bq_connection_check self._session_started = False @property @@ -105,14 +107,16 @@ def project(self, value: Optional[str]): @property def bq_connection(self) -> Optional[str]: - """Name of the BigQuery connection to use. Should be of the form ... + """Name of the BigQuery connection to use. Should be of the form + ... 
You should either have the connection already created in the location you have chosen, or you should have the Project IAM Admin role to enable the service to create the connection for you if you need it. - If this option isn't provided, or project or location aren't provided, session will use its default project/location/connection_id as default connection. + If this option isn't provided, or project or location aren't provided, + session will use its default project/location/connection_id as default connection. """ return self._bq_connection @@ -122,6 +126,26 @@ def bq_connection(self, value: Optional[str]): raise ValueError(SESSION_STARTED_MESSAGE.format(attribute="bq_connection")) self._bq_connection = value + @property + def skip_bq_connection_check(self) -> bool: + """Forcibly use the BigQuery connection. + + Setting this flag to True would avoid creating the BigQuery connection + and checking or setting IAM permissions on it. So if the BigQuery + connection (default or user-provided) does not exist, or it does not have + necessary permissions set up to support BigQuery DataFrames operations, + then a runtime error will be reported. + """ + return self._skip_bq_connection_check + + @skip_bq_connection_check.setter + def skip_bq_connection_check(self, value: bool): + if self._session_started and self._skip_bq_connection_check != value: + raise ValueError( + SESSION_STARTED_MESSAGE.format(attribute="skip_bq_connection_check") + ) + self._skip_bq_connection_check = value + @property def use_regional_endpoints(self) -> bool: """Flag to connect to regional API endpoints. diff --git a/bigframes/clients.py b/bigframes/clients.py index de2421e499..8a2dbfed6c 100644 --- a/bigframes/clients.py +++ b/bigframes/clients.py @@ -27,6 +27,23 @@ logger = logging.getLogger(__name__) +def resolve_full_bq_connection_name( + connection_name: str, default_project: str, default_location: str +) -> str: + """Retrieve the full connection name of the form ... 
+ Use default project, location or connection_id when any of them are missing.""" + if connection_name.count(".") == 2: + return connection_name + + if connection_name.count(".") == 1: + return f"{default_project}.{connection_name}" + + if connection_name.count(".") == 0: + return f"{default_project}.{default_location}.{connection_name}" + + raise ValueError(f"Invalid connection name format: {connection_name}.") + + class BqConnectionManager: """Manager to handle operations with BQ connections.""" @@ -41,23 +58,6 @@ def __init__( self._bq_connection_client = bq_connection_client self._cloud_resource_manager_client = cloud_resource_manager_client - @classmethod - def resolve_full_connection_name( - cls, connection_name: str, default_project: str, default_location: str - ) -> str: - """Retrieve the full connection name of the form ... - Use default project, location or connection_id when any of them are missing.""" - if connection_name.count(".") == 2: - return connection_name - - if connection_name.count(".") == 1: - return f"{default_project}.{connection_name}" - - if connection_name.count(".") == 0: - return f"{default_project}.{default_location}.{connection_name}" - - raise ValueError(f"Invalid connection name format: {connection_name}.") - def create_bq_connection( self, project_id: str, location: str, connection_id: str, iam_role: str ): @@ -73,12 +73,6 @@ def create_bq_connection( iam_role: str of the IAM role that the service account of the created connection needs to aquire. E.g. 'run.invoker', 'aiplatform.user' """ - # TODO(shobs): The below command to enable BigQuery Connection API needs - # to be automated. Disabling for now since most target users would not - # have the privilege to enable API in a project. 
- # log("Making sure BigQuery Connection API is enabled") - # if os.system("gcloud services enable bigqueryconnection.googleapis.com"): - # raise ValueError("Failed to enable BigQuery Connection API") # If the intended connection does not exist then create it service_account_id = self._get_service_account_if_connection_exists( project_id, location, connection_id diff --git a/bigframes/functions/remote_function.py b/bigframes/functions/remote_function.py index bfb272d992..178c911591 100644 --- a/bigframes/functions/remote_function.py +++ b/bigframes/functions/remote_function.py @@ -126,9 +126,8 @@ def __init__( bq_location, bq_dataset, bq_client, - bq_connection_client, bq_connection_id, - cloud_resource_manager_client, + bq_connection_manager, cloud_function_service_account, cloud_function_kms_key_name, cloud_function_docker_repository, @@ -140,9 +139,7 @@ def __init__( self._bq_dataset = bq_dataset self._bq_client = bq_client self._bq_connection_id = bq_connection_id - self._bq_connection_manager = clients.BqConnectionManager( - bq_connection_client, cloud_resource_manager_client - ) + self._bq_connection_manager = bq_connection_manager self._cloud_function_service_account = cloud_function_service_account self._cloud_function_kms_key_name = cloud_function_kms_key_name self._cloud_function_docker_repository = cloud_function_docker_repository @@ -152,12 +149,13 @@ def create_bq_remote_function( ): """Create a BigQuery remote function given the artifacts of a user defined function and the http endpoint of a corresponding cloud function.""" - self._bq_connection_manager.create_bq_connection( - self._gcp_project_id, - self._bq_location, - self._bq_connection_id, - "run.invoker", - ) + if self._bq_connection_manager: + self._bq_connection_manager.create_bq_connection( + self._gcp_project_id, + self._bq_location, + self._bq_connection_id, + "run.invoker", + ) # Create BQ function # 
https://cloud.google.com/bigquery/docs/reference/standard-sql/remote-functions#create_a_remote_function_2 @@ -784,7 +782,7 @@ def remote_function( if not bigquery_connection: bigquery_connection = session._bq_connection # type: ignore - bigquery_connection = clients.BqConnectionManager.resolve_full_connection_name( + bigquery_connection = clients.resolve_full_bq_connection_name( bigquery_connection, default_project=dataset_ref.project, default_location=bq_location, @@ -816,6 +814,8 @@ def remote_function( " For more details see https://cloud.google.com/functions/docs/securing/cmek#before_you_begin" ) + bq_connection_manager = None if session is None else session.bqconnectionmanager + def wrapper(f): if not callable(f): raise TypeError("f must be callable, got {}".format(f)) @@ -832,9 +832,8 @@ def wrapper(f): bq_location, dataset_ref.dataset_id, bigquery_client, - bigquery_connection_client, bq_connection_id, - resource_manager_client, + bq_connection_manager, cloud_function_service_account, cloud_function_kms_key_name, cloud_function_docker_repository, diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 10c3cc51b2..6c4ae2ea43 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -73,12 +73,10 @@ def __init__( ): self.model_name = model_name self.session = session or bpd.get_global_session() - self._bq_connection_manager = clients.BqConnectionManager( - self.session.bqconnectionclient, self.session.resourcemanagerclient - ) + self._bq_connection_manager = self.session.bqconnectionmanager connection_name = connection_name or self.session._bq_connection - self.connection_name = self._bq_connection_manager.resolve_full_connection_name( + self.connection_name = clients.resolve_full_bq_connection_name( connection_name, default_project=self.session._project, default_location=self.session._location, @@ -93,17 +91,19 @@ def _create_bqml_model(self): raise ValueError( "Must provide connection_name, either in constructor or through session options." 
) - connection_name_parts = self.connection_name.split(".") - if len(connection_name_parts) != 3: - raise ValueError( - f"connection_name must be of the format .., got {self.connection_name}." + + if self._bq_connection_manager: + connection_name_parts = self.connection_name.split(".") + if len(connection_name_parts) != 3: + raise ValueError( + f"connection_name must be of the format .., got {self.connection_name}." + ) + self._bq_connection_manager.create_bq_connection( + project_id=connection_name_parts[0], + location=connection_name_parts[1], + connection_id=connection_name_parts[2], + iam_role="aiplatform.user", ) - self._bq_connection_manager.create_bq_connection( - project_id=connection_name_parts[0], - location=connection_name_parts[1], - connection_id=connection_name_parts[2], - iam_role="aiplatform.user", - ) if self.model_name not in _TEXT_GENERATOR_ENDPOINTS: raise ValueError( @@ -289,12 +289,10 @@ def __init__( self.model_name = model_name self.version = version self.session = session or bpd.get_global_session() - self._bq_connection_manager = clients.BqConnectionManager( - self.session.bqconnectionclient, self.session.resourcemanagerclient - ) + self._bq_connection_manager = self.session.bqconnectionmanager connection_name = connection_name or self.session._bq_connection - self.connection_name = self._bq_connection_manager.resolve_full_connection_name( + self.connection_name = clients.resolve_full_bq_connection_name( connection_name, default_project=self.session._project, default_location=self.session._location, @@ -309,17 +307,19 @@ def _create_bqml_model(self): raise ValueError( "Must provide connection_name, either in constructor or through session options." ) - connection_name_parts = self.connection_name.split(".") - if len(connection_name_parts) != 3: - raise ValueError( - f"connection_name must be of the format .., got {self.connection_name}." 
+ + if self._bq_connection_manager: + connection_name_parts = self.connection_name.split(".") + if len(connection_name_parts) != 3: + raise ValueError( + f"connection_name must be of the format .., got {self.connection_name}." + ) + self._bq_connection_manager.create_bq_connection( + project_id=connection_name_parts[0], + location=connection_name_parts[1], + connection_id=connection_name_parts[2], + iam_role="aiplatform.user", ) - self._bq_connection_manager.create_bq_connection( - project_id=connection_name_parts[0], - location=connection_name_parts[1], - connection_id=connection_name_parts[2], - iam_role="aiplatform.user", - ) if self.model_name not in _EMBEDDING_GENERATOR_ENDPOINTS: raise ValueError( @@ -437,12 +437,10 @@ def __init__( connection_name: Optional[str] = None, ): self.session = session or bpd.get_global_session() - self._bq_connection_manager = clients.BqConnectionManager( - self.session.bqconnectionclient, self.session.resourcemanagerclient - ) + self._bq_connection_manager = self.session.bqconnectionmanager connection_name = connection_name or self.session._bq_connection - self.connection_name = self._bq_connection_manager.resolve_full_connection_name( + self.connection_name = clients.resolve_full_bq_connection_name( connection_name, default_project=self.session._project, default_location=self.session._location, @@ -457,17 +455,19 @@ def _create_bqml_model(self): raise ValueError( "Must provide connection_name, either in constructor or through session options." ) - connection_name_parts = self.connection_name.split(".") - if len(connection_name_parts) != 3: - raise ValueError( - f"connection_name must be of the format .., got {self.connection_name}." + + if self._bq_connection_manager: + connection_name_parts = self.connection_name.split(".") + if len(connection_name_parts) != 3: + raise ValueError( + f"connection_name must be of the format .., got {self.connection_name}." 
+ ) + self._bq_connection_manager.create_bq_connection( + project_id=connection_name_parts[0], + location=connection_name_parts[1], + connection_id=connection_name_parts[2], + iam_role="aiplatform.user", ) - self._bq_connection_manager.create_bq_connection( - project_id=connection_name_parts[0], - location=connection_name_parts[1], - connection_id=connection_name_parts[2], - iam_role="aiplatform.user", - ) options = {"endpoint": _GEMINI_PRO_ENDPOINT} diff --git a/bigframes/ml/remote.py b/bigframes/ml/remote.py index 2b83382e68..8cf892f536 100644 --- a/bigframes/ml/remote.py +++ b/bigframes/ml/remote.py @@ -62,11 +62,9 @@ def __init__( self.output = output self.session = session or bpd.get_global_session() - self._bq_connection_manager = clients.BqConnectionManager( - self.session.bqconnectionclient, self.session.resourcemanagerclient - ) + self._bq_connection_manager = self.session.bqconnectionmanager connection_name = connection_name or self.session._bq_connection - self.connection_name = self._bq_connection_manager.resolve_full_connection_name( + self.connection_name = clients.resolve_full_bq_connection_name( connection_name, default_project=self.session._project, default_location=self.session._location, @@ -81,17 +79,19 @@ def _create_bqml_model(self): raise ValueError( "Must provide connection_name, either in constructor or through session options." ) - connection_name_parts = self.connection_name.split(".") - if len(connection_name_parts) != 3: - raise ValueError( - f"connection_name must be of the format .., got {self.connection_name}." + + if self._bq_connection_manager: + connection_name_parts = self.connection_name.split(".") + if len(connection_name_parts) != 3: + raise ValueError( + f"connection_name must be of the format .., got {self.connection_name}." 
+ ) + self._bq_connection_manager.create_bq_connection( + project_id=connection_name_parts[0], + location=connection_name_parts[1], + connection_id=connection_name_parts[2], + iam_role="aiplatform.user", ) - self._bq_connection_manager.create_bq_connection( - project_id=connection_name_parts[0], - location=connection_name_parts[1], - connection_id=connection_name_parts[2], - iam_role="aiplatform.user", - ) options = { "endpoint": self.endpoint, diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 3f59e58df2..5732d4b08e 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -73,6 +73,7 @@ import pyarrow as pa import bigframes._config.bigquery_options as bigquery_options +import bigframes.clients import bigframes.constants as constants import bigframes.core as core import bigframes.core.blocks as blocks @@ -153,7 +154,7 @@ class Session( Configuration adjusting how to connect to BigQuery and related APIs. Note that some options are ignored if ``clients_provider`` is set. - clients_provider (bigframes.session.bigframes.session.clients.ClientsProvider): + clients_provider (bigframes.session.clients.ClientsProvider): An object providing client library objects. """ @@ -212,6 +213,7 @@ def __init__( # Resolve the BQ connection for remote function and Vertex AI integration self._bq_connection = context.bq_connection or _BIGFRAMES_DEFAULT_CONNECTION_ID + self._skip_bq_connection_check = context._skip_bq_connection_check # Now that we're starting the session, don't allow the options to be # changed. 
@@ -238,6 +240,16 @@ def cloudfunctionsclient(self): def resourcemanagerclient(self): return self._clients_provider.resourcemanagerclient + _bq_connection_manager: Optional[bigframes.clients.BqConnectionManager] = None + + @property + def bqconnectionmanager(self): + if not self._skip_bq_connection_check and not self._bq_connection_manager: + self._bq_connection_manager = bigframes.clients.BqConnectionManager( + self.bqconnectionclient, self.resourcemanagerclient + ) + return self._bq_connection_manager + @property def _project(self): return self.bqclient.project diff --git a/tests/system/small/test_remote_function.py b/tests/system/small/test_remote_function.py index 1cf494ea6b..e7e434dbd0 100644 --- a/tests/system/small/test_remote_function.py +++ b/tests/system/small/test_remote_function.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import google.api_core.exceptions from google.cloud import bigquery import pandas as pd import pytest @@ -466,6 +467,39 @@ def add_one(x): ) +def test_skip_bq_connection_check(dataset_id_permanent): + connection_name = "connection_does_not_exist" + session = bigframes.Session( + context=bigframes.BigQueryOptions( + bq_connection=connection_name, skip_bq_connection_check=True + ) + ) + + # Make sure that the connection does not exist + with pytest.raises(google.api_core.exceptions.NotFound): + session.bqconnectionclient.get_connection( + name=session.bqconnectionclient.connection_path( + session._project, session._location, connection_name + ) + ) + + # Make sure that an attempt to create a remote function routine with + # non-existent connection would result in an exception thrown by the BQ + # service. 
+ # This is different from the exception thrown by the BQ Connection service + # if it was not able to create the connection because of lack of permission + # when skip_bq_connection_check was not set to True: + # google.api_core.exceptions.PermissionDenied: 403 Permission 'resourcemanager.projects.setIamPolicy' denied on resource + with pytest.raises( + google.api_core.exceptions.NotFound, + match=f"Not found: Connection {connection_name}", + ): + + @session.remote_function([int], int, dataset=dataset_id_permanent) + def add_one(x): + return x + 1 + + @pytest.mark.flaky(retries=2, delay=120) def test_read_gbq_function_detects_invalid_function(bigquery_client, dataset_id): dataset_ref = bigquery.DatasetReference.from_string(dataset_id) diff --git a/tests/unit/_config/test_bigquery_options.py b/tests/unit/_config/test_bigquery_options.py index 1ce70e3da2..cf13084610 100644 --- a/tests/unit/_config/test_bigquery_options.py +++ b/tests/unit/_config/test_bigquery_options.py @@ -30,6 +30,7 @@ ("bq_connection", "path/to/connection/1", "path/to/connection/2"), ("use_regional_endpoints", False, True), ("kms_key_name", "kms/key/name/1", "kms/key/name/2"), + ("skip_bq_connection_check", False, True), ], ) def test_setter_raises_if_session_started(attribute, original_value, new_value): diff --git a/tests/unit/test_clients.py b/tests/unit/test_clients.py index f89cc21397..37450ececb 100644 --- a/tests/unit/test_clients.py +++ b/tests/unit/test_clients.py @@ -18,21 +18,21 @@ def test_get_connection_name_full_connection_id(): - connection_name = clients.BqConnectionManager.resolve_full_connection_name( + connection_name = clients.resolve_full_bq_connection_name( "connection-id", default_project="default-project", default_location="us" ) assert connection_name == "default-project.us.connection-id" def test_get_connection_name_full_location_connection_id(): - connection_name = clients.BqConnectionManager.resolve_full_connection_name( + connection_name = 
clients.resolve_full_bq_connection_name( "eu.connection-id", default_project="default-project", default_location="us" ) assert connection_name == "default-project.eu.connection-id" def test_get_connection_name_full_all(): - connection_name = clients.BqConnectionManager.resolve_full_connection_name( + connection_name = clients.resolve_full_bq_connection_name( "my-project.eu.connection-id", default_project="default-project", default_location="us", @@ -42,7 +42,7 @@ def test_get_connection_name_full_all(): def test_get_connection_name_full_raise_value_error(): with pytest.raises(ValueError): - clients.BqConnectionManager.resolve_full_connection_name( + clients.resolve_full_bq_connection_name( "my-project.eu.connection-id.extra_field", default_project="default-project", default_location="us",