Skip to content

Commit

Permalink
feat: option to use bq connection without check (#460)
Browse files Browse the repository at this point in the history
* feat: option to use bq connection without check

* revert breaking signature change, centralize connection manager skipping

* fix bad referencing

* use public property from session

* revert unintended test_iam_permissions change

* fix couple of more unwanted changes
  • Loading branch information
shobsi authored Mar 21, 2024
1 parent a5345fe commit 0b3f8e5
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 100 deletions.
28 changes: 26 additions & 2 deletions bigframes/_config/bigquery_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
use_regional_endpoints: bool = False,
application_name: Optional[str] = None,
kms_key_name: Optional[str] = None,
skip_bq_connection_check: bool = False,
):
self._credentials = credentials
self._project = project
Expand All @@ -48,6 +49,7 @@ def __init__(
self._use_regional_endpoints = use_regional_endpoints
self._application_name = application_name
self._kms_key_name = kms_key_name
self._skip_bq_connection_check = skip_bq_connection_check
self._session_started = False

@property
Expand Down Expand Up @@ -105,14 +107,16 @@ def project(self, value: Optional[str]):

@property
def bq_connection(self) -> Optional[str]:
"""Name of the BigQuery connection to use. Should be of the form <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
"""Name of the BigQuery connection to use. Should be of the form
<PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
You should either have the connection already created in the
<code>location</code> you have chosen, or you should have the Project IAM
Admin role to enable the service to create the connection for you if you
need it.
If this option isn't provided, or project or location aren't provided, session will use its default project/location/connection_id as default connection.
If this option isn't provided, or project or location aren't provided,
session will use its default project/location/connection_id as default connection.
"""
return self._bq_connection

Expand All @@ -122,6 +126,26 @@ def bq_connection(self, value: Optional[str]):
raise ValueError(SESSION_STARTED_MESSAGE.format(attribute="bq_connection"))
self._bq_connection = value

@property
def skip_bq_connection_check(self) -> bool:
"""Forcibly use the BigQuery connection.
Setting this flag to True would avoid creating the BigQuery connection
and checking or setting IAM permissions on it. So if the BigQuery
connection (default or user-provided) does not exist, or it does not have
necessary permissions set up to support BigQuery DataFrames operations,
then a runtime error will be reported.
"""
return self._skip_bq_connection_check

@skip_bq_connection_check.setter
def skip_bq_connection_check(self, value: bool):
if self._session_started and self._skip_bq_connection_check != value:
raise ValueError(
SESSION_STARTED_MESSAGE.format(attribute="skip_bq_connection_check")
)
self._skip_bq_connection_check = value

@property
def use_regional_endpoints(self) -> bool:
"""Flag to connect to regional API endpoints.
Expand Down
40 changes: 17 additions & 23 deletions bigframes/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,23 @@
logger = logging.getLogger(__name__)


def resolve_full_bq_connection_name(
connection_name: str, default_project: str, default_location: str
) -> str:
"""Retrieve the full connection name of the form <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
Use default project, location or connection_id when any of them are missing."""
if connection_name.count(".") == 2:
return connection_name

if connection_name.count(".") == 1:
return f"{default_project}.{connection_name}"

if connection_name.count(".") == 0:
return f"{default_project}.{default_location}.{connection_name}"

raise ValueError(f"Invalid connection name format: {connection_name}.")


class BqConnectionManager:
"""Manager to handle operations with BQ connections."""

Expand All @@ -41,23 +58,6 @@ def __init__(
self._bq_connection_client = bq_connection_client
self._cloud_resource_manager_client = cloud_resource_manager_client

@classmethod
def resolve_full_connection_name(
cls, connection_name: str, default_project: str, default_location: str
) -> str:
"""Retrieve the full connection name of the form <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
Use default project, location or connection_id when any of them are missing."""
if connection_name.count(".") == 2:
return connection_name

if connection_name.count(".") == 1:
return f"{default_project}.{connection_name}"

if connection_name.count(".") == 0:
return f"{default_project}.{default_location}.{connection_name}"

raise ValueError(f"Invalid connection name format: {connection_name}.")

def create_bq_connection(
self, project_id: str, location: str, connection_id: str, iam_role: str
):
Expand All @@ -73,12 +73,6 @@ def create_bq_connection(
iam_role:
str of the IAM role that the service account of the created connection needs to aquire. E.g. 'run.invoker', 'aiplatform.user'
"""
# TODO(shobs): The below command to enable BigQuery Connection API needs
# to be automated. Disabling for now since most target users would not
# have the privilege to enable API in a project.
# log("Making sure BigQuery Connection API is enabled")
# if os.system("gcloud services enable bigqueryconnection.googleapis.com"):
# raise ValueError("Failed to enable BigQuery Connection API")
# If the intended connection does not exist then create it
service_account_id = self._get_service_account_if_connection_exists(
project_id, location, connection_id
Expand Down
27 changes: 13 additions & 14 deletions bigframes/functions/remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,8 @@ def __init__(
bq_location,
bq_dataset,
bq_client,
bq_connection_client,
bq_connection_id,
cloud_resource_manager_client,
bq_connection_manager,
cloud_function_service_account,
cloud_function_kms_key_name,
cloud_function_docker_repository,
Expand All @@ -140,9 +139,7 @@ def __init__(
self._bq_dataset = bq_dataset
self._bq_client = bq_client
self._bq_connection_id = bq_connection_id
self._bq_connection_manager = clients.BqConnectionManager(
bq_connection_client, cloud_resource_manager_client
)
self._bq_connection_manager = bq_connection_manager
self._cloud_function_service_account = cloud_function_service_account
self._cloud_function_kms_key_name = cloud_function_kms_key_name
self._cloud_function_docker_repository = cloud_function_docker_repository
Expand All @@ -152,12 +149,13 @@ def create_bq_remote_function(
):
"""Create a BigQuery remote function given the artifacts of a user defined
function and the http endpoint of a corresponding cloud function."""
self._bq_connection_manager.create_bq_connection(
self._gcp_project_id,
self._bq_location,
self._bq_connection_id,
"run.invoker",
)
if self._bq_connection_manager:
self._bq_connection_manager.create_bq_connection(
self._gcp_project_id,
self._bq_location,
self._bq_connection_id,
"run.invoker",
)

# Create BQ function
# https://cloud.google.com/bigquery/docs/reference/standard-sql/remote-functions#create_a_remote_function_2
Expand Down Expand Up @@ -784,7 +782,7 @@ def remote_function(
if not bigquery_connection:
bigquery_connection = session._bq_connection # type: ignore

bigquery_connection = clients.BqConnectionManager.resolve_full_connection_name(
bigquery_connection = clients.resolve_full_bq_connection_name(
bigquery_connection,
default_project=dataset_ref.project,
default_location=bq_location,
Expand Down Expand Up @@ -816,6 +814,8 @@ def remote_function(
" For more details see https://cloud.google.com/functions/docs/securing/cmek#before_you_begin"
)

bq_connection_manager = None if session is None else session.bqconnectionmanager

def wrapper(f):
if not callable(f):
raise TypeError("f must be callable, got {}".format(f))
Expand All @@ -832,9 +832,8 @@ def wrapper(f):
bq_location,
dataset_ref.dataset_id,
bigquery_client,
bigquery_connection_client,
bq_connection_id,
resource_manager_client,
bq_connection_manager,
cloud_function_service_account,
cloud_function_kms_key_name,
cloud_function_docker_repository,
Expand Down
84 changes: 42 additions & 42 deletions bigframes/ml/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,10 @@ def __init__(
):
self.model_name = model_name
self.session = session or bpd.get_global_session()
self._bq_connection_manager = clients.BqConnectionManager(
self.session.bqconnectionclient, self.session.resourcemanagerclient
)
self._bq_connection_manager = self.session.bqconnectionmanager

connection_name = connection_name or self.session._bq_connection
self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
self.connection_name = clients.resolve_full_bq_connection_name(
connection_name,
default_project=self.session._project,
default_location=self.session._location,
Expand All @@ -93,17 +91,19 @@ def _create_bqml_model(self):
raise ValueError(
"Must provide connection_name, either in constructor or through session options."
)
connection_name_parts = self.connection_name.split(".")
if len(connection_name_parts) != 3:
raise ValueError(
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."

if self._bq_connection_manager:
connection_name_parts = self.connection_name.split(".")
if len(connection_name_parts) != 3:
raise ValueError(
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
)
self._bq_connection_manager.create_bq_connection(
project_id=connection_name_parts[0],
location=connection_name_parts[1],
connection_id=connection_name_parts[2],
iam_role="aiplatform.user",
)
self._bq_connection_manager.create_bq_connection(
project_id=connection_name_parts[0],
location=connection_name_parts[1],
connection_id=connection_name_parts[2],
iam_role="aiplatform.user",
)

if self.model_name not in _TEXT_GENERATOR_ENDPOINTS:
raise ValueError(
Expand Down Expand Up @@ -289,12 +289,10 @@ def __init__(
self.model_name = model_name
self.version = version
self.session = session or bpd.get_global_session()
self._bq_connection_manager = clients.BqConnectionManager(
self.session.bqconnectionclient, self.session.resourcemanagerclient
)
self._bq_connection_manager = self.session.bqconnectionmanager

connection_name = connection_name or self.session._bq_connection
self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
self.connection_name = clients.resolve_full_bq_connection_name(
connection_name,
default_project=self.session._project,
default_location=self.session._location,
Expand All @@ -309,17 +307,19 @@ def _create_bqml_model(self):
raise ValueError(
"Must provide connection_name, either in constructor or through session options."
)
connection_name_parts = self.connection_name.split(".")
if len(connection_name_parts) != 3:
raise ValueError(
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."

if self._bq_connection_manager:
connection_name_parts = self.connection_name.split(".")
if len(connection_name_parts) != 3:
raise ValueError(
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
)
self._bq_connection_manager.create_bq_connection(
project_id=connection_name_parts[0],
location=connection_name_parts[1],
connection_id=connection_name_parts[2],
iam_role="aiplatform.user",
)
self._bq_connection_manager.create_bq_connection(
project_id=connection_name_parts[0],
location=connection_name_parts[1],
connection_id=connection_name_parts[2],
iam_role="aiplatform.user",
)

if self.model_name not in _EMBEDDING_GENERATOR_ENDPOINTS:
raise ValueError(
Expand Down Expand Up @@ -437,12 +437,10 @@ def __init__(
connection_name: Optional[str] = None,
):
self.session = session or bpd.get_global_session()
self._bq_connection_manager = clients.BqConnectionManager(
self.session.bqconnectionclient, self.session.resourcemanagerclient
)
self._bq_connection_manager = self.session.bqconnectionmanager

connection_name = connection_name or self.session._bq_connection
self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
self.connection_name = clients.resolve_full_bq_connection_name(
connection_name,
default_project=self.session._project,
default_location=self.session._location,
Expand All @@ -457,17 +455,19 @@ def _create_bqml_model(self):
raise ValueError(
"Must provide connection_name, either in constructor or through session options."
)
connection_name_parts = self.connection_name.split(".")
if len(connection_name_parts) != 3:
raise ValueError(
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."

if self._bq_connection_manager:
connection_name_parts = self.connection_name.split(".")
if len(connection_name_parts) != 3:
raise ValueError(
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
)
self._bq_connection_manager.create_bq_connection(
project_id=connection_name_parts[0],
location=connection_name_parts[1],
connection_id=connection_name_parts[2],
iam_role="aiplatform.user",
)
self._bq_connection_manager.create_bq_connection(
project_id=connection_name_parts[0],
location=connection_name_parts[1],
connection_id=connection_name_parts[2],
iam_role="aiplatform.user",
)

options = {"endpoint": _GEMINI_PRO_ENDPOINT}

Expand Down
28 changes: 14 additions & 14 deletions bigframes/ml/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,9 @@ def __init__(
self.output = output
self.session = session or bpd.get_global_session()

self._bq_connection_manager = clients.BqConnectionManager(
self.session.bqconnectionclient, self.session.resourcemanagerclient
)
self._bq_connection_manager = self.session.bqconnectionmanager
connection_name = connection_name or self.session._bq_connection
self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
self.connection_name = clients.resolve_full_bq_connection_name(
connection_name,
default_project=self.session._project,
default_location=self.session._location,
Expand All @@ -81,17 +79,19 @@ def _create_bqml_model(self):
raise ValueError(
"Must provide connection_name, either in constructor or through session options."
)
connection_name_parts = self.connection_name.split(".")
if len(connection_name_parts) != 3:
raise ValueError(
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."

if self._bq_connection_manager:
connection_name_parts = self.connection_name.split(".")
if len(connection_name_parts) != 3:
raise ValueError(
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
)
self._bq_connection_manager.create_bq_connection(
project_id=connection_name_parts[0],
location=connection_name_parts[1],
connection_id=connection_name_parts[2],
iam_role="aiplatform.user",
)
self._bq_connection_manager.create_bq_connection(
project_id=connection_name_parts[0],
location=connection_name_parts[1],
connection_id=connection_name_parts[2],
iam_role="aiplatform.user",
)

options = {
"endpoint": self.endpoint,
Expand Down
Loading

0 comments on commit 0b3f8e5

Please sign in to comment.