From bb77dd6ea08fc009591643e4fc9a1d4da30f7ffc Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Mon, 18 Mar 2024 21:59:48 +0000 Subject: [PATCH 1/6] feat: option to use bq connection without check --- bigframes/_config/bigquery_options.py | 33 ++++++- bigframes/clients.py | 6 -- bigframes/functions/remote_function.py | 34 +++++--- bigframes/ml/llm.py | 96 ++++++++++++--------- bigframes/ml/remote.py | 32 ++++--- bigframes/session/__init__.py | 1 + tests/system/small/test_remote_function.py | 34 ++++++++ tests/unit/_config/test_bigquery_options.py | 1 + 8 files changed, 162 insertions(+), 75 deletions(-) diff --git a/bigframes/_config/bigquery_options.py b/bigframes/_config/bigquery_options.py index 34701740f6..8a59193fba 100644 --- a/bigframes/_config/bigquery_options.py +++ b/bigframes/_config/bigquery_options.py @@ -33,19 +33,22 @@ class BigQueryOptions: def __init__( self, + *, credentials: Optional[google.auth.credentials.Credentials] = None, project: Optional[str] = None, location: Optional[str] = None, - bq_connection: Optional[str] = None, use_regional_endpoints: bool = False, + bq_connection: Optional[str] = None, + skip_bq_connection_check: bool = False, application_name: Optional[str] = None, kms_key_name: Optional[str] = None, ): self._credentials = credentials self._project = project self._location = location - self._bq_connection = bq_connection self._use_regional_endpoints = use_regional_endpoints + self._bq_connection = bq_connection + self._skip_bq_connection_check = skip_bq_connection_check self._application_name = application_name self._kms_key_name = kms_key_name self._session_started = False @@ -105,14 +108,16 @@ def project(self, value: Optional[str]): @property def bq_connection(self) -> Optional[str]: - """Name of the BigQuery connection to use. Should be of the form ... + """Name of the BigQuery connection to use. Should be of the form + ... 
You should either have the connection already created in the location you have chosen, or you should have the Project IAM Admin role to enable the service to create the connection for you if you need it. - If this option isn't provided, or project or location aren't provided, session will use its default project/location/connection_id as default connection. + If this option isn't provided, or project or location aren't provided, + session will use its default project/location/connection_id as default connection. """ return self._bq_connection @@ -122,6 +127,26 @@ def bq_connection(self, value: Optional[str]): raise ValueError(SESSION_STARTED_MESSAGE.format(attribute="bq_connection")) self._bq_connection = value + @property + def skip_bq_connection_check(self) -> bool: + """Forcibly use the BigQuery connection. + + Setting this flag to True would avoid creating the BigQuery connection + and checking or setting IAM permissions on it. So if the BigQuery + connection (default or user-provided) does not exist, or it does not have + necessary permissions set up to support BigQuery DataFrames operations, + then a runtime error will be reported. + """ + return self._skip_bq_connection_check + + @skip_bq_connection_check.setter + def skip_bq_connection_check(self, value: bool): + if self._session_started and self._skip_bq_connection_check != value: + raise ValueError( + SESSION_STARTED_MESSAGE.format(attribute="skip_bq_connection_check") + ) + self._skip_bq_connection_check = value + @property def use_regional_endpoints(self) -> bool: """Flag to connect to regional API endpoints. diff --git a/bigframes/clients.py b/bigframes/clients.py index de2421e499..abdcc7b80d 100644 --- a/bigframes/clients.py +++ b/bigframes/clients.py @@ -73,12 +73,6 @@ def create_bq_connection( iam_role: str of the IAM role that the service account of the created connection needs to aquire. E.g. 
'run.invoker', 'aiplatform.user' """ - # TODO(shobs): The below command to enable BigQuery Connection API needs - # to be automated. Disabling for now since most target users would not - # have the privilege to enable API in a project. - # log("Making sure BigQuery Connection API is enabled") - # if os.system("gcloud services enable bigqueryconnection.googleapis.com"): - # raise ValueError("Failed to enable BigQuery Connection API") # If the intended connection does not exist then create it service_account_id = self._get_service_account_if_connection_exists( project_id, location, connection_id diff --git a/bigframes/functions/remote_function.py b/bigframes/functions/remote_function.py index 09a9d97869..e61a038e45 100644 --- a/bigframes/functions/remote_function.py +++ b/bigframes/functions/remote_function.py @@ -126,9 +126,8 @@ def __init__( bq_location, bq_dataset, bq_client, - bq_connection_client, bq_connection_id, - cloud_resource_manager_client, + bq_connection_manager, cloud_function_service_account, cloud_function_kms_key_name, cloud_function_docker_repository, @@ -140,9 +139,7 @@ def __init__( self._bq_dataset = bq_dataset self._bq_client = bq_client self._bq_connection_id = bq_connection_id - self._bq_connection_manager = clients.BqConnectionManager( - bq_connection_client, cloud_resource_manager_client - ) + self._bq_connection_manager = bq_connection_manager self._cloud_function_service_account = cloud_function_service_account self._cloud_function_kms_key_name = cloud_function_kms_key_name self._cloud_function_docker_repository = cloud_function_docker_repository @@ -152,12 +149,13 @@ def create_bq_remote_function( ): """Create a BigQuery remote function given the artifacts of a user defined function and the http endpoint of a corresponding cloud function.""" - self._bq_connection_manager.create_bq_connection( - self._gcp_project_id, - self._bq_location, - self._bq_connection_id, - "run.invoker", - ) + if self._bq_connection_manager: + 
self._bq_connection_manager.create_bq_connection( + self._gcp_project_id, + self._bq_location, + self._bq_connection_id, + "run.invoker", + ) # Create BQ function # https://cloud.google.com/bigquery/docs/reference/standard-sql/remote-functions#create_a_remote_function_2 @@ -816,6 +814,17 @@ def remote_function( " For more details see https://cloud.google.com/functions/docs/securing/cmek#before_you_begin" ) + skip_bq_connection_check = ( + False if session is None else session._skip_bq_connection_check + ) + bq_connection_manager = ( + None + if skip_bq_connection_check + else clients.BqConnectionManager( + bigquery_connection_client, resource_manager_client + ) + ) + def wrapper(f): if not callable(f): raise TypeError("f must be callable, got {}".format(f)) @@ -832,9 +841,8 @@ def wrapper(f): bq_location, dataset_ref.dataset_id, bigquery_client, - bigquery_connection_client, bq_connection_id, - resource_manager_client, + bq_connection_manager, cloud_function_service_account, cloud_function_kms_key_name, cloud_function_docker_repository, diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 79f6b90bfd..a3ae77f42d 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -73,12 +73,16 @@ def __init__( ): self.model_name = model_name self.session = session or bpd.get_global_session() - self._bq_connection_manager = clients.BqConnectionManager( - self.session.bqconnectionclient, self.session.resourcemanagerclient + self._bq_connection_manager = ( + None + if self.session._skip_bq_connection_check + else clients.BqConnectionManager( + self.session.bqconnectionclient, self.session.resourcemanagerclient + ) ) connection_name = connection_name or self.session._bq_connection - self.connection_name = self._bq_connection_manager.resolve_full_connection_name( + self.connection_name = clients.BqConnectionManager.resolve_full_connection_name( connection_name, default_project=self.session._project, default_location=self.session._location, @@ -93,17 +97,19 @@ def 
_create_bqml_model(self): raise ValueError( "Must provide connection_name, either in constructor or through session options." ) - connection_name_parts = self.connection_name.split(".") - if len(connection_name_parts) != 3: - raise ValueError( - f"connection_name must be of the format .., got {self.connection_name}." + + if self._bq_connection_manager: + connection_name_parts = self.connection_name.split(".") + if len(connection_name_parts) != 3: + raise ValueError( + f"connection_name must be of the format .., got {self.connection_name}." + ) + self._bq_connection_manager.create_bq_connection( + project_id=connection_name_parts[0], + location=connection_name_parts[1], + connection_id=connection_name_parts[2], + iam_role="aiplatform.user", ) - self._bq_connection_manager.create_bq_connection( - project_id=connection_name_parts[0], - location=connection_name_parts[1], - connection_id=connection_name_parts[2], - iam_role="aiplatform.user", - ) if self.model_name not in _TEXT_GENERATOR_ENDPOINTS: raise ValueError( @@ -289,12 +295,16 @@ def __init__( self.model_name = model_name self.version = version self.session = session or bpd.get_global_session() - self._bq_connection_manager = clients.BqConnectionManager( - self.session.bqconnectionclient, self.session.resourcemanagerclient + self._bq_connection_manager = ( + None + if self.session._skip_bq_connection_check + else clients.BqConnectionManager( + self.session.bqconnectionclient, self.session.resourcemanagerclient + ) ) connection_name = connection_name or self.session._bq_connection - self.connection_name = self._bq_connection_manager.resolve_full_connection_name( + self.connection_name = clients.BqConnectionManager.resolve_full_connection_name( connection_name, default_project=self.session._project, default_location=self.session._location, @@ -309,17 +319,19 @@ def _create_bqml_model(self): raise ValueError( "Must provide connection_name, either in constructor or through session options." 
) - connection_name_parts = self.connection_name.split(".") - if len(connection_name_parts) != 3: - raise ValueError( - f"connection_name must be of the format .., got {self.connection_name}." + + if self._bq_connection_manager: + connection_name_parts = self.connection_name.split(".") + if len(connection_name_parts) != 3: + raise ValueError( + f"connection_name must be of the format .., got {self.connection_name}." + ) + self._bq_connection_manager.create_bq_connection( + project_id=connection_name_parts[0], + location=connection_name_parts[1], + connection_id=connection_name_parts[2], + iam_role="aiplatform.user", ) - self._bq_connection_manager.create_bq_connection( - project_id=connection_name_parts[0], - location=connection_name_parts[1], - connection_id=connection_name_parts[2], - iam_role="aiplatform.user", - ) if self.model_name not in _EMBEDDING_GENERATOR_ENDPOINTS: raise ValueError( @@ -437,12 +449,16 @@ def __init__( connection_name: Optional[str] = None, ): self.session = session or bpd.get_global_session() - self._bq_connection_manager = clients.BqConnectionManager( - self.session.bqconnectionclient, self.session.resourcemanagerclient + self._bq_connection_manager = ( + None + if self.session._skip_bq_connection_check + else clients.BqConnectionManager( + self.session.bqconnectionclient, self.session.resourcemanagerclient + ) ) connection_name = connection_name or self.session._bq_connection - self.connection_name = self._bq_connection_manager.resolve_full_connection_name( + self.connection_name = clients.BqConnectionManager.resolve_full_connection_name( connection_name, default_project=self.session._project, default_location=self.session._location, @@ -457,17 +473,19 @@ def _create_bqml_model(self): raise ValueError( "Must provide connection_name, either in constructor or through session options." 
) - connection_name_parts = self.connection_name.split(".") - if len(connection_name_parts) != 3: - raise ValueError( - f"connection_name must be of the format .., got {self.connection_name}." + + if self._bq_connection_manager: + connection_name_parts = self.connection_name.split(".") + if len(connection_name_parts) != 3: + raise ValueError( + f"connection_name must be of the format .., got {self.connection_name}." + ) + self._bq_connection_manager.create_bq_connection( + project_id=connection_name_parts[0], + location=connection_name_parts[1], + connection_id=connection_name_parts[2], + iam_role="aiplatform.user", ) - self._bq_connection_manager.create_bq_connection( - project_id=connection_name_parts[0], - location=connection_name_parts[1], - connection_id=connection_name_parts[2], - iam_role="aiplatform.user", - ) options = {"endpoint": _GEMINI_PRO_ENDPOINT} diff --git a/bigframes/ml/remote.py b/bigframes/ml/remote.py index 2b83382e68..bfe1a91241 100644 --- a/bigframes/ml/remote.py +++ b/bigframes/ml/remote.py @@ -62,11 +62,15 @@ def __init__( self.output = output self.session = session or bpd.get_global_session() - self._bq_connection_manager = clients.BqConnectionManager( - self.session.bqconnectionclient, self.session.resourcemanagerclient + self._bq_connection_manager = ( + None + if self.session._skip_bq_connection_check + else clients.BqConnectionManager( + self.session.bqconnectionclient, self.session.resourcemanagerclient + ) ) connection_name = connection_name or self.session._bq_connection - self.connection_name = self._bq_connection_manager.resolve_full_connection_name( + self.connection_name = clients.BqConnectionManager.resolve_full_connection_name( connection_name, default_project=self.session._project, default_location=self.session._location, @@ -81,17 +85,19 @@ def _create_bqml_model(self): raise ValueError( "Must provide connection_name, either in constructor or through session options." 
) - connection_name_parts = self.connection_name.split(".") - if len(connection_name_parts) != 3: - raise ValueError( - f"connection_name must be of the format .., got {self.connection_name}." + + if self._bq_connection_manager: + connection_name_parts = self.connection_name.split(".") + if len(connection_name_parts) != 3: + raise ValueError( + f"connection_name must be of the format .., got {self.connection_name}." + ) + self._bq_connection_manager.create_bq_connection( + project_id=connection_name_parts[0], + location=connection_name_parts[1], + connection_id=connection_name_parts[2], + iam_role="aiplatform.user", ) - self._bq_connection_manager.create_bq_connection( - project_id=connection_name_parts[0], - location=connection_name_parts[1], - connection_id=connection_name_parts[2], - iam_role="aiplatform.user", - ) options = { "endpoint": self.endpoint, diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 40831292de..01e8da95ec 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -196,6 +196,7 @@ def __init__( # Resolve the BQ connection for remote function and Vertex AI integration self._bq_connection = context.bq_connection or _BIGFRAMES_DEFAULT_CONNECTION_ID + self._skip_bq_connection_check = context._skip_bq_connection_check # Now that we're starting the session, don't allow the options to be # changed. diff --git a/tests/system/small/test_remote_function.py b/tests/system/small/test_remote_function.py index 1cf494ea6b..e7e434dbd0 100644 --- a/tests/system/small/test_remote_function.py +++ b/tests/system/small/test_remote_function.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import google.api_core.exceptions from google.cloud import bigquery import pandas as pd import pytest @@ -466,6 +467,39 @@ def add_one(x): ) +def test_skip_bq_connection_check(dataset_id_permanent): + connection_name = "connection_does_not_exist" + session = bigframes.Session( + context=bigframes.BigQueryOptions( + bq_connection=connection_name, skip_bq_connection_check=True + ) + ) + + # Make sure that the connection does not exist + with pytest.raises(google.api_core.exceptions.NotFound): + session.bqconnectionclient.get_connection( + name=session.bqconnectionclient.connection_path( + session._project, session._location, connection_name + ) + ) + + # Make sure that an attempt to create a remote function routine with + # non-existent connection would result in an exception thrown by the BQ + # service. + # This is different from the exception thrown by the BQ Connection service + # if it was not able to create the connection because of lack of permission + # when skip_bq_connection_check was not set to True: + # google.api_core.exceptions.PermissionDenied: 403 Permission 'resourcemanager.projects.setIamPolicy' denied on resource + with pytest.raises( + google.api_core.exceptions.NotFound, + match=f"Not found: Connection {connection_name}", + ): + + @session.remote_function([int], int, dataset=dataset_id_permanent) + def add_one(x): + return x + 1 + + @pytest.mark.flaky(retries=2, delay=120) def test_read_gbq_function_detects_invalid_function(bigquery_client, dataset_id): dataset_ref = bigquery.DatasetReference.from_string(dataset_id) diff --git a/tests/unit/_config/test_bigquery_options.py b/tests/unit/_config/test_bigquery_options.py index 1ce70e3da2..cf13084610 100644 --- a/tests/unit/_config/test_bigquery_options.py +++ b/tests/unit/_config/test_bigquery_options.py @@ -30,6 +30,7 @@ ("bq_connection", "path/to/connection/1", "path/to/connection/2"), ("use_regional_endpoints", False, True), ("kms_key_name", "kms/key/name/1", "kms/key/name/2"), + 
("skip_bq_connection_check", False, True), ], ) def test_setter_raises_if_session_started(attribute, original_value, new_value): From de5e6faf4f6dd8b0d2e6ca3ed5971afb3dc9c718 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Wed, 20 Mar 2024 21:12:06 +0000 Subject: [PATCH 2/6] revert breaking signature change, centralize connection manager skipping --- bigframes/_config/bigquery_options.py | 9 +++--- bigframes/clients.py | 40 +++++++++++++++----------- bigframes/functions/remote_function.py | 13 ++------- bigframes/ml/llm.py | 30 ++++--------------- bigframes/ml/remote.py | 10 ++----- bigframes/session/__init__.py | 17 +++++++++-- tests/unit/test_clients.py | 8 +++--- 7 files changed, 55 insertions(+), 72 deletions(-) diff --git a/bigframes/_config/bigquery_options.py b/bigframes/_config/bigquery_options.py index 8a59193fba..d035fe5df1 100644 --- a/bigframes/_config/bigquery_options.py +++ b/bigframes/_config/bigquery_options.py @@ -33,24 +33,23 @@ class BigQueryOptions: def __init__( self, - *, credentials: Optional[google.auth.credentials.Credentials] = None, project: Optional[str] = None, location: Optional[str] = None, - use_regional_endpoints: bool = False, bq_connection: Optional[str] = None, - skip_bq_connection_check: bool = False, + use_regional_endpoints: bool = False, application_name: Optional[str] = None, kms_key_name: Optional[str] = None, + skip_bq_connection_check: bool = False, ): self._credentials = credentials self._project = project self._location = location - self._use_regional_endpoints = use_regional_endpoints self._bq_connection = bq_connection - self._skip_bq_connection_check = skip_bq_connection_check + self._use_regional_endpoints = use_regional_endpoints self._application_name = application_name self._kms_key_name = kms_key_name + self._skip_bq_connection_check = skip_bq_connection_check self._session_started = False @property diff --git a/bigframes/clients.py b/bigframes/clients.py index abdcc7b80d..116d0821e5 100644 --- 
a/bigframes/clients.py +++ b/bigframes/clients.py @@ -27,6 +27,23 @@ logger = logging.getLogger(__name__) +def resolve_full_bq_connection_name( + connection_name: str, default_project: str, default_location: str +) -> str: + """Retrieve the full connection name of the form <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>. + Use default project, location or connection_id when any of them are missing.""" + if connection_name.count(".") == 2: + return connection_name + + if connection_name.count(".") == 1: + return f"{default_project}.{connection_name}" + + if connection_name.count(".") == 0: + return f"{default_project}.{default_location}.{connection_name}" + + raise ValueError(f"Invalid connection name format: {connection_name}.") + + class BqConnectionManager: """Manager to handle operations with BQ connections.""" @@ -41,23 +58,6 @@ def __init__( self._bq_connection_client = bq_connection_client self._cloud_resource_manager_client = cloud_resource_manager_client - @classmethod - def resolve_full_connection_name( - cls, connection_name: str, default_project: str, default_location: str - ) -> str: - """Retrieve the full connection name of the form <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>. 
- Use default project, location or connection_id when any of them are missing.""" - if connection_name.count(".") == 2: - return connection_name - - if connection_name.count(".") == 1: - return f"{default_project}.{connection_name}" - - if connection_name.count(".") == 0: - return f"{default_project}.{default_location}.{connection_name}" - - raise ValueError(f"Invalid connection name format: {connection_name}.") - def create_bq_connection( self, project_id: str, location: str, connection_id: str, iam_role: str ): @@ -116,6 +116,12 @@ def _ensure_iam_binding( project = f"projects/{project_id}" service_account = f"serviceAccount:{service_account_id}" role = f"roles/{iam_role}" + # request = iam_policy_pb2.TestIamPermissionsRequest() + self._cloud_resource_manager_client.test_iam_permissions( + resource=project, + permissions=[role], + ) + request = iam_policy_pb2.GetIamPolicyRequest(resource=project) policy = self._cloud_resource_manager_client.get_iam_policy(request=request) diff --git a/bigframes/functions/remote_function.py b/bigframes/functions/remote_function.py index e61a038e45..9a1c6c7db4 100644 --- a/bigframes/functions/remote_function.py +++ b/bigframes/functions/remote_function.py @@ -782,7 +782,7 @@ def remote_function( if not bigquery_connection: bigquery_connection = session._bq_connection # type: ignore - bigquery_connection = clients.BqConnectionManager.resolve_full_connection_name( + bigquery_connection = clients.resolve_full_bq_connection_name( bigquery_connection, default_project=dataset_ref.project, default_location=bq_location, @@ -814,16 +814,7 @@ def remote_function( " For more details see https://cloud.google.com/functions/docs/securing/cmek#before_you_begin" ) - skip_bq_connection_check = ( - False if session is None else session._skip_bq_connection_check - ) - bq_connection_manager = ( - None - if skip_bq_connection_check - else clients.BqConnectionManager( - bigquery_connection_client, resource_manager_client - ) - ) + bq_connection_manager 
= False if session is None else session.bqconnectionmanager def wrapper(f): if not callable(f): diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 3a0f29be04..f2ca7a98f6 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -73,16 +73,10 @@ def __init__( ): self.model_name = model_name self.session = session or bpd.get_global_session() - self._bq_connection_manager = ( - None - if self.session._skip_bq_connection_check - else clients.BqConnectionManager( - self.session.bqconnectionclient, self.session.resourcemanagerclient - ) - ) + self._bq_connection_manager = self.session._bq_connection_manager connection_name = connection_name or self.session._bq_connection - self.connection_name = clients.BqConnectionManager.resolve_full_connection_name( + self.connection_name = clients.resolve_full_bq_connection_name( connection_name, default_project=self.session._project, default_location=self.session._location, @@ -295,16 +289,10 @@ def __init__( self.model_name = model_name self.version = version self.session = session or bpd.get_global_session() - self._bq_connection_manager = ( - None - if self.session._skip_bq_connection_check - else clients.BqConnectionManager( - self.session.bqconnectionclient, self.session.resourcemanagerclient - ) - ) + self._bq_connection_manager = self.session._bq_connection_manager connection_name = connection_name or self.session._bq_connection - self.connection_name = clients.BqConnectionManager.resolve_full_connection_name( + self.connection_name = clients.resolve_full_bq_connection_name( connection_name, default_project=self.session._project, default_location=self.session._location, @@ -449,16 +437,10 @@ def __init__( connection_name: Optional[str] = None, ): self.session = session or bpd.get_global_session() - self._bq_connection_manager = ( - None - if self.session._skip_bq_connection_check - else clients.BqConnectionManager( - self.session.bqconnectionclient, self.session.resourcemanagerclient - ) - ) + 
self._bq_connection_manager = self.session._bq_connection_manager connection_name = connection_name or self.session._bq_connection - self.connection_name = clients.BqConnectionManager.resolve_full_connection_name( + self.connection_name = clients.resolve_full_bq_connection_name( connection_name, default_project=self.session._project, default_location=self.session._location, diff --git a/bigframes/ml/remote.py b/bigframes/ml/remote.py index bfe1a91241..8cf892f536 100644 --- a/bigframes/ml/remote.py +++ b/bigframes/ml/remote.py @@ -62,15 +62,9 @@ def __init__( self.output = output self.session = session or bpd.get_global_session() - self._bq_connection_manager = ( - None - if self.session._skip_bq_connection_check - else clients.BqConnectionManager( - self.session.bqconnectionclient, self.session.resourcemanagerclient - ) - ) + self._bq_connection_manager = self.session.bqconnectionmanager connection_name = connection_name or self.session._bq_connection - self.connection_name = clients.BqConnectionManager.resolve_full_connection_name( + self.connection_name = clients.resolve_full_bq_connection_name( connection_name, default_project=self.session._project, default_location=self.session._location, diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 0f44154693..c31994ce7a 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -39,8 +39,6 @@ ) import warnings -# Even though the ibis.backends.bigquery import is unused, it's needed -# to register new and replacement ops with the Ibis BigQuery backend. import bigframes_vendored.ibis.backends.bigquery # noqa import bigframes_vendored.ibis.expr.operations as vendored_ibis_ops import bigframes_vendored.pandas.io.gbq as third_party_pandas_gbq @@ -71,7 +69,10 @@ StorageOptions, ) +# Even though the ibis.backends.bigquery import is unused, it's needed +# to register new and replacement ops with the Ibis BigQuery backend. 
import bigframes._config.bigquery_options as bigquery_options +import bigframes.clients import bigframes.constants as constants import bigframes.core as core import bigframes.core.blocks as blocks @@ -137,7 +138,7 @@ class Session( Configuration adjusting how to connect to BigQuery and related APIs. Note that some options are ignored if ``clients_provider`` is set. - clients_provider (bigframes.session.bigframes.session.clients.ClientsProvider): + clients_provider (bigframes.session.clients.ClientsProvider): An object providing client library objects. """ @@ -223,6 +224,16 @@ def cloudfunctionsclient(self): def resourcemanagerclient(self): return self._clients_provider.resourcemanagerclient + _bq_connection_manager: Optional[bigframes.clients.BqConnectionManager] = None + + @property + def bqconnectionmanager(self): + if not self._skip_bq_connection_check and not self._bq_connection_manager: + self._bq_connection_manager = bigframes.clients.BqConnectionManager( + self.session.bqconnectionclient, self.session.resourcemanagerclient + ) + return self._bq_connection_manager + @property def _project(self): return self.bqclient.project diff --git a/tests/unit/test_clients.py b/tests/unit/test_clients.py index f89cc21397..37450ececb 100644 --- a/tests/unit/test_clients.py +++ b/tests/unit/test_clients.py @@ -18,21 +18,21 @@ def test_get_connection_name_full_connection_id(): - connection_name = clients.BqConnectionManager.resolve_full_connection_name( + connection_name = clients.resolve_full_bq_connection_name( "connection-id", default_project="default-project", default_location="us" ) assert connection_name == "default-project.us.connection-id" def test_get_connection_name_full_location_connection_id(): - connection_name = clients.BqConnectionManager.resolve_full_connection_name( + connection_name = clients.resolve_full_bq_connection_name( "eu.connection-id", default_project="default-project", default_location="us" ) assert connection_name == 
"default-project.eu.connection-id" def test_get_connection_name_full_all(): - connection_name = clients.BqConnectionManager.resolve_full_connection_name( + connection_name = clients.resolve_full_bq_connection_name( "my-project.eu.connection-id", default_project="default-project", default_location="us", @@ -42,7 +42,7 @@ def test_get_connection_name_full_all(): def test_get_connection_name_full_raise_value_error(): with pytest.raises(ValueError): - clients.BqConnectionManager.resolve_full_connection_name( + clients.resolve_full_bq_connection_name( "my-project.eu.connection-id.extra_field", default_project="default-project", default_location="us", From 3da2f6b1d224bafc5119451d3646107d282ab62a Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Wed, 20 Mar 2024 22:04:44 +0000 Subject: [PATCH 3/6] fix bad referencing --- bigframes/session/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index c31994ce7a..042066844e 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -230,7 +230,7 @@ def resourcemanagerclient(self): def bqconnectionmanager(self): if not self._skip_bq_connection_check and not self._bq_connection_manager: self._bq_connection_manager = bigframes.clients.BqConnectionManager( - self.session.bqconnectionclient, self.session.resourcemanagerclient + self.bqconnectionclient, self.resourcemanagerclient ) return self._bq_connection_manager From 84037d5fdb4869e785bb19640744373b9149be62 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Wed, 20 Mar 2024 22:08:52 +0000 Subject: [PATCH 4/6] use public property from session --- bigframes/ml/llm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index f2ca7a98f6..6c4ae2ea43 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -73,7 +73,7 @@ def __init__( ): self.model_name = model_name self.session = session or 
bpd.get_global_session() - self._bq_connection_manager = self.session._bq_connection_manager + self._bq_connection_manager = self.session.bqconnectionmanager connection_name = connection_name or self.session._bq_connection self.connection_name = clients.resolve_full_bq_connection_name( @@ -289,7 +289,7 @@ def __init__( self.model_name = model_name self.version = version self.session = session or bpd.get_global_session() - self._bq_connection_manager = self.session._bq_connection_manager + self._bq_connection_manager = self.session.bqconnectionmanager connection_name = connection_name or self.session._bq_connection self.connection_name = clients.resolve_full_bq_connection_name( @@ -437,7 +437,7 @@ def __init__( connection_name: Optional[str] = None, ): self.session = session or bpd.get_global_session() - self._bq_connection_manager = self.session._bq_connection_manager + self._bq_connection_manager = self.session.bqconnectionmanager connection_name = connection_name or self.session._bq_connection self.connection_name = clients.resolve_full_bq_connection_name( From 6fcb6b75a523fb47bacc9c7c4cef4779647176d5 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Wed, 20 Mar 2024 22:36:09 +0000 Subject: [PATCH 5/6] revert unintended test_iam_permissions change --- bigframes/clients.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/bigframes/clients.py b/bigframes/clients.py index 116d0821e5..8a2dbfed6c 100644 --- a/bigframes/clients.py +++ b/bigframes/clients.py @@ -116,12 +116,6 @@ def _ensure_iam_binding( project = f"projects/{project_id}" service_account = f"serviceAccount:{service_account_id}" role = f"roles/{iam_role}" - # request = iam_policy_pb2.TestIamPermissionsRequest() - self._cloud_resource_manager_client.test_iam_permissions( - resource=project, - permissions=[role], - ) - request = iam_policy_pb2.GetIamPolicyRequest(resource=project) policy = self._cloud_resource_manager_client.get_iam_policy(request=request) From 
467eb031acd98fed4e8d2bb02ca445d1ecb318b7 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Wed, 20 Mar 2024 22:43:58 +0000 Subject: [PATCH 6/6] fix couple of more unwanted changes --- bigframes/functions/remote_function.py | 2 +- bigframes/session/__init__.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bigframes/functions/remote_function.py b/bigframes/functions/remote_function.py index 2f20ac38cf..178c911591 100644 --- a/bigframes/functions/remote_function.py +++ b/bigframes/functions/remote_function.py @@ -814,7 +814,7 @@ def remote_function( " For more details see https://cloud.google.com/functions/docs/securing/cmek#before_you_begin" ) - bq_connection_manager = False if session is None else session.bqconnectionmanager + bq_connection_manager = None if session is None else session.bqconnectionmanager def wrapper(f): if not callable(f): diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 53acfeeeae..5732d4b08e 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -39,6 +39,8 @@ ) import warnings +# Even though the ibis.backends.bigquery import is unused, it's needed +# to register new and replacement ops with the Ibis BigQuery backend. import bigframes_vendored.ibis.backends.bigquery # noqa import bigframes_vendored.ibis.expr.operations as vendored_ibis_ops import bigframes_vendored.pandas.io.gbq as third_party_pandas_gbq @@ -70,8 +72,6 @@ ) import pyarrow as pa -# Even though the ibis.backends.bigquery import is unused, it's needed -# to register new and replacement ops with the Ibis BigQuery backend. import bigframes._config.bigquery_options as bigquery_options import bigframes.clients import bigframes.constants as constants