Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: option to use bq connection without check #460

Merged
merged 10 commits into from
Mar 21, 2024
28 changes: 26 additions & 2 deletions bigframes/_config/bigquery_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
use_regional_endpoints: bool = False,
application_name: Optional[str] = None,
kms_key_name: Optional[str] = None,
skip_bq_connection_check: bool = False,
):
self._credentials = credentials
self._project = project
Expand All @@ -48,6 +49,7 @@ def __init__(
self._use_regional_endpoints = use_regional_endpoints
self._application_name = application_name
self._kms_key_name = kms_key_name
self._skip_bq_connection_check = skip_bq_connection_check
self._session_started = False

@property
Expand Down Expand Up @@ -105,14 +107,16 @@ def project(self, value: Optional[str]):

@property
def bq_connection(self) -> Optional[str]:
"""Name of the BigQuery connection to use. Should be of the form <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
"""Name of the BigQuery connection to use. Should be of the form
<PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.

You should either have the connection already created in the
<code>location</code> you have chosen, or you should have the Project IAM
Admin role to enable the service to create the connection for you if you
need it.

If this option isn't provided, or project or location aren't provided, session will use its default project/location/connection_id as default connection.
If this option isn't provided, or project or location aren't provided,
session will use its default project/location/connection_id as default connection.
"""
return self._bq_connection

Expand All @@ -122,6 +126,26 @@ def bq_connection(self, value: Optional[str]):
raise ValueError(SESSION_STARTED_MESSAGE.format(attribute="bq_connection"))
self._bq_connection = value

@property
def skip_bq_connection_check(self) -> bool:
"""Forcibly use the BigQuery connection.

Setting this flag to True would avoid creating the BigQuery connection
and checking or setting IAM permissions on it. So if the BigQuery
connection (default or user-provided) does not exist, or it does not have
necessary permissions set up to support BigQuery DataFrames operations,
then a runtime error will be reported.
"""
return self._skip_bq_connection_check

@skip_bq_connection_check.setter
def skip_bq_connection_check(self, value: bool):
if self._session_started and self._skip_bq_connection_check != value:
raise ValueError(
SESSION_STARTED_MESSAGE.format(attribute="skip_bq_connection_check")
)
self._skip_bq_connection_check = value

@property
def use_regional_endpoints(self) -> bool:
"""Flag to connect to regional API endpoints.
Expand Down
40 changes: 17 additions & 23 deletions bigframes/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,23 @@
logger = logging.getLogger(__name__)


def resolve_full_bq_connection_name(
connection_name: str, default_project: str, default_location: str
) -> str:
"""Retrieve the full connection name of the form <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
Use default project, location or connection_id when any of them are missing."""
if connection_name.count(".") == 2:
return connection_name

if connection_name.count(".") == 1:
return f"{default_project}.{connection_name}"

if connection_name.count(".") == 0:
return f"{default_project}.{default_location}.{connection_name}"

raise ValueError(f"Invalid connection name format: {connection_name}.")


class BqConnectionManager:
"""Manager to handle operations with BQ connections."""

Expand All @@ -41,23 +58,6 @@ def __init__(
self._bq_connection_client = bq_connection_client
self._cloud_resource_manager_client = cloud_resource_manager_client

@classmethod
def resolve_full_connection_name(
cls, connection_name: str, default_project: str, default_location: str
) -> str:
"""Retrieve the full connection name of the form <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
Use default project, location or connection_id when any of them are missing."""
if connection_name.count(".") == 2:
return connection_name

if connection_name.count(".") == 1:
return f"{default_project}.{connection_name}"

if connection_name.count(".") == 0:
return f"{default_project}.{default_location}.{connection_name}"

raise ValueError(f"Invalid connection name format: {connection_name}.")

def create_bq_connection(
self, project_id: str, location: str, connection_id: str, iam_role: str
):
Expand All @@ -73,12 +73,6 @@ def create_bq_connection(
iam_role:
str of the IAM role that the service account of the created connection needs to aquire. E.g. 'run.invoker', 'aiplatform.user'
"""
# TODO(shobs): The below command to enable BigQuery Connection API needs
# to be automated. Disabling for now since most target users would not
# have the privilege to enable API in a project.
# log("Making sure BigQuery Connection API is enabled")
# if os.system("gcloud services enable bigqueryconnection.googleapis.com"):
# raise ValueError("Failed to enable BigQuery Connection API")
# If the intended connection does not exist then create it
service_account_id = self._get_service_account_if_connection_exists(
project_id, location, connection_id
Expand Down
27 changes: 13 additions & 14 deletions bigframes/functions/remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,8 @@ def __init__(
bq_location,
bq_dataset,
bq_client,
bq_connection_client,
bq_connection_id,
cloud_resource_manager_client,
bq_connection_manager,
cloud_function_service_account,
cloud_function_kms_key_name,
cloud_function_docker_repository,
Expand All @@ -140,9 +139,7 @@ def __init__(
self._bq_dataset = bq_dataset
self._bq_client = bq_client
self._bq_connection_id = bq_connection_id
self._bq_connection_manager = clients.BqConnectionManager(
bq_connection_client, cloud_resource_manager_client
)
self._bq_connection_manager = bq_connection_manager
self._cloud_function_service_account = cloud_function_service_account
self._cloud_function_kms_key_name = cloud_function_kms_key_name
self._cloud_function_docker_repository = cloud_function_docker_repository
Expand All @@ -152,12 +149,13 @@ def create_bq_remote_function(
):
"""Create a BigQuery remote function given the artifacts of a user defined
function and the http endpoint of a corresponding cloud function."""
self._bq_connection_manager.create_bq_connection(
self._gcp_project_id,
self._bq_location,
self._bq_connection_id,
"run.invoker",
)
if self._bq_connection_manager:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would think of creating a dummy ConnectionManager class, and when skip_bq_connection_check is True, use the dummy implementation instead of a real one. Then the same logic can apply to all the places, and the user (of the connection_manager) doesn't need to think about the implementations.

Well just a suggestion, up to you.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Um, gave it a thought, feels weird to implement a class with a method create_bq_connection which doesn't create anything. I'd prefer to not do that. Let me know if you have strong opinion about it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case we'd rename the create_bq_connection to sth like prepare_bq_connection.

No need to put too much thoughts on it. Just a suggestion.

self._bq_connection_manager.create_bq_connection(
self._gcp_project_id,
self._bq_location,
self._bq_connection_id,
"run.invoker",
)

# Create BQ function
# https://cloud.google.com/bigquery/docs/reference/standard-sql/remote-functions#create_a_remote_function_2
Expand Down Expand Up @@ -784,7 +782,7 @@ def remote_function(
if not bigquery_connection:
bigquery_connection = session._bq_connection # type: ignore

bigquery_connection = clients.BqConnectionManager.resolve_full_connection_name(
bigquery_connection = clients.resolve_full_bq_connection_name(
bigquery_connection,
default_project=dataset_ref.project,
default_location=bq_location,
Expand Down Expand Up @@ -816,6 +814,8 @@ def remote_function(
" For more details see https://cloud.google.com/functions/docs/securing/cmek#before_you_begin"
)

bq_connection_manager = None if session is None else session.bqconnectionmanager

def wrapper(f):
if not callable(f):
raise TypeError("f must be callable, got {}".format(f))
Expand All @@ -832,9 +832,8 @@ def wrapper(f):
bq_location,
dataset_ref.dataset_id,
bigquery_client,
bigquery_connection_client,
bq_connection_id,
resource_manager_client,
bq_connection_manager,
cloud_function_service_account,
cloud_function_kms_key_name,
cloud_function_docker_repository,
Expand Down
84 changes: 42 additions & 42 deletions bigframes/ml/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,10 @@ def __init__(
):
self.model_name = model_name
self.session = session or bpd.get_global_session()
self._bq_connection_manager = clients.BqConnectionManager(
self.session.bqconnectionclient, self.session.resourcemanagerclient
)
self._bq_connection_manager = self.session.bqconnectionmanager

connection_name = connection_name or self.session._bq_connection
self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
self.connection_name = clients.resolve_full_bq_connection_name(
connection_name,
default_project=self.session._project,
default_location=self.session._location,
Expand All @@ -93,17 +91,19 @@ def _create_bqml_model(self):
raise ValueError(
"Must provide connection_name, either in constructor or through session options."
)
connection_name_parts = self.connection_name.split(".")
if len(connection_name_parts) != 3:
raise ValueError(
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."

if self._bq_connection_manager:
connection_name_parts = self.connection_name.split(".")
if len(connection_name_parts) != 3:
raise ValueError(
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
)
self._bq_connection_manager.create_bq_connection(
project_id=connection_name_parts[0],
location=connection_name_parts[1],
connection_id=connection_name_parts[2],
iam_role="aiplatform.user",
)
self._bq_connection_manager.create_bq_connection(
project_id=connection_name_parts[0],
location=connection_name_parts[1],
connection_id=connection_name_parts[2],
iam_role="aiplatform.user",
)

if self.model_name not in _TEXT_GENERATOR_ENDPOINTS:
raise ValueError(
Expand Down Expand Up @@ -289,12 +289,10 @@ def __init__(
self.model_name = model_name
self.version = version
self.session = session or bpd.get_global_session()
self._bq_connection_manager = clients.BqConnectionManager(
self.session.bqconnectionclient, self.session.resourcemanagerclient
)
self._bq_connection_manager = self.session.bqconnectionmanager

connection_name = connection_name or self.session._bq_connection
self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
self.connection_name = clients.resolve_full_bq_connection_name(
connection_name,
default_project=self.session._project,
default_location=self.session._location,
Expand All @@ -309,17 +307,19 @@ def _create_bqml_model(self):
raise ValueError(
"Must provide connection_name, either in constructor or through session options."
)
connection_name_parts = self.connection_name.split(".")
if len(connection_name_parts) != 3:
raise ValueError(
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."

if self._bq_connection_manager:
connection_name_parts = self.connection_name.split(".")
if len(connection_name_parts) != 3:
raise ValueError(
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
)
self._bq_connection_manager.create_bq_connection(
project_id=connection_name_parts[0],
location=connection_name_parts[1],
connection_id=connection_name_parts[2],
iam_role="aiplatform.user",
)
self._bq_connection_manager.create_bq_connection(
project_id=connection_name_parts[0],
location=connection_name_parts[1],
connection_id=connection_name_parts[2],
iam_role="aiplatform.user",
)

if self.model_name not in _EMBEDDING_GENERATOR_ENDPOINTS:
raise ValueError(
Expand Down Expand Up @@ -437,12 +437,10 @@ def __init__(
connection_name: Optional[str] = None,
):
self.session = session or bpd.get_global_session()
self._bq_connection_manager = clients.BqConnectionManager(
self.session.bqconnectionclient, self.session.resourcemanagerclient
)
self._bq_connection_manager = self.session.bqconnectionmanager

connection_name = connection_name or self.session._bq_connection
self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
self.connection_name = clients.resolve_full_bq_connection_name(
connection_name,
default_project=self.session._project,
default_location=self.session._location,
Expand All @@ -457,17 +455,19 @@ def _create_bqml_model(self):
raise ValueError(
"Must provide connection_name, either in constructor or through session options."
)
connection_name_parts = self.connection_name.split(".")
if len(connection_name_parts) != 3:
raise ValueError(
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."

if self._bq_connection_manager:
connection_name_parts = self.connection_name.split(".")
if len(connection_name_parts) != 3:
raise ValueError(
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
)
self._bq_connection_manager.create_bq_connection(
project_id=connection_name_parts[0],
location=connection_name_parts[1],
connection_id=connection_name_parts[2],
iam_role="aiplatform.user",
)
self._bq_connection_manager.create_bq_connection(
project_id=connection_name_parts[0],
location=connection_name_parts[1],
connection_id=connection_name_parts[2],
iam_role="aiplatform.user",
)

options = {"endpoint": _GEMINI_PRO_ENDPOINT}

Expand Down
28 changes: 14 additions & 14 deletions bigframes/ml/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,9 @@ def __init__(
self.output = output
self.session = session or bpd.get_global_session()

self._bq_connection_manager = clients.BqConnectionManager(
self.session.bqconnectionclient, self.session.resourcemanagerclient
)
self._bq_connection_manager = self.session.bqconnectionmanager
connection_name = connection_name or self.session._bq_connection
self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
self.connection_name = clients.resolve_full_bq_connection_name(
connection_name,
default_project=self.session._project,
default_location=self.session._location,
Expand All @@ -81,17 +79,19 @@ def _create_bqml_model(self):
raise ValueError(
"Must provide connection_name, either in constructor or through session options."
)
connection_name_parts = self.connection_name.split(".")
if len(connection_name_parts) != 3:
raise ValueError(
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."

if self._bq_connection_manager:
connection_name_parts = self.connection_name.split(".")
if len(connection_name_parts) != 3:
raise ValueError(
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
)
self._bq_connection_manager.create_bq_connection(
project_id=connection_name_parts[0],
location=connection_name_parts[1],
connection_id=connection_name_parts[2],
iam_role="aiplatform.user",
)
self._bq_connection_manager.create_bq_connection(
project_id=connection_name_parts[0],
location=connection_name_parts[1],
connection_id=connection_name_parts[2],
iam_role="aiplatform.user",
)

options = {
"endpoint": self.endpoint,
Expand Down
Loading