feat: Support CMEK for BQ tables #403

Merged · 25 commits · Mar 9, 2024

Changes from 11 commits

Commits (25)
144190e  feat: Support CMEK for BQ tables (shobsi, Feb 29, 2024)
eda994e  Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek (shobsi, Feb 29, 2024)
6b41ec0  add more tests (shobsi, Mar 1, 2024)
d2f3f3d  add unit tests (shobsi, Mar 1, 2024)
425cf12  add more tests, fix broken tests (shobsi, Mar 1, 2024)
2443d40  separate bqml client to send kms_key_name via OPTIONS instead of job (shobsi, Mar 4, 2024)
ca50e6c  Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek (shobsi, Mar 4, 2024)
07d97f4  fix unit tests (shobsi, Mar 4, 2024)
950dd27  fix mypy (shobsi, Mar 4, 2024)
ee717fe  Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek (shobsi, Mar 4, 2024)
927d8a2  skip cmek test for empty cmek (shobsi, Mar 4, 2024)
827bef2  move staticmethods to helper module (shobsi, Mar 5, 2024)
e1b3258  Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek (shobsi, Mar 6, 2024)
f1c8b00  revert bqmlclient, pass cmek through call time job config (shobsi, Mar 7, 2024)
28064f8  Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek (shobsi, Mar 7, 2024)
073f989  revert bqmlclient unit test (shobsi, Mar 7, 2024)
79a4b73  fix mypy failure (shobsi, Mar 7, 2024)
313bb12  use better named key, disable use_query_cache in test (shobsi, Mar 7, 2024)
9c8d064  Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek (shobsi, Mar 7, 2024)
4d4bb10  rename bqml create model internal method (shobsi, Mar 8, 2024)
0eff3a6  fix renamed methods's reference in unit tests (shobsi, Mar 8, 2024)
185edd9  remove stray bqmlclient variable (shobsi, Mar 8, 2024)
07dca56  Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek (shobsi, Mar 8, 2024)
86ebba7  Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek (shobsi, Mar 8, 2024)
2ae5361  Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek (shobsi, Mar 8, 2024)
25 changes: 25 additions & 0 deletions bigframes/_config/bigquery_options.py
@@ -39,13 +39,15 @@ def __init__(
bq_connection: Optional[str] = None,
use_regional_endpoints: bool = False,
application_name: Optional[str] = None,
kms_key_name: Optional[str] = None,
):
self._credentials = credentials
self._project = project
self._location = location
self._bq_connection = bq_connection
self._use_regional_endpoints = use_regional_endpoints
self._application_name = application_name
self._kms_key_name = kms_key_name
self._session_started = False

@property
@@ -148,3 +150,26 @@ def use_regional_endpoints(self, value: bool):
)

self._use_regional_endpoints = value

@property
def kms_key_name(self) -> Optional[str]:
"""Customer managed encryption key used to control encryption of the
data-at-rest in BigQuery. This is of the format
projects/PROJECT_ID/locations/LOCATION/keyRings/KEYRING/cryptoKeys/KEY

See https://cloud.google.com/bigquery/docs/customer-managed-encryption
for more details.

Please make sure the project used for Bigquery DataFrames has "Cloud KMS
CryptoKey Encrypter/Decrypter" role in the key's project, See
https://cloud.google.com/bigquery/docs/customer-managed-encryption#assign_role
for steps on how to ensure that.
"""
return self._kms_key_name

@kms_key_name.setter
def kms_key_name(self, value: str):
if self._session_started and self._kms_key_name != value:
raise ValueError(SESSION_STARTED_MESSAGE.format(attribute="kms_key_name"))

self._kms_key_name = value
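
For orientation, here is a minimal usage sketch of the new option from the user's side. This is not part of the PR diff; the key path is a placeholder and the read_gbq call is only illustrative.

```python
# Minimal usage sketch: configure the customer-managed key before the session
# starts. The key path below is a placeholder.
import bigframes.pandas as bpd

bpd.options.bigquery.kms_key_name = (
    "projects/PROJECT_ID/locations/LOCATION/keyRings/KEYRING/cryptoKeys/KEY"
)

# A session created after this point attaches the key to its query, load,
# and copy jobs (see the session and client changes further down).
df = bpd.read_gbq("bigquery-public-data.ml_datasets.penguins")
```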
13 changes: 10 additions & 3 deletions bigframes/ml/core.py
@@ -212,7 +212,11 @@ def principal_component_info(self) -> bpd.DataFrame:
return self._session.read_gbq(sql)

def copy(self, new_model_name: str, replace: bool = False) -> BqmlModel:
job_config = bigquery.job.CopyJobConfig()
job_config = bigquery.job.CopyJobConfig(
destination_encryption_configuration=bigquery.EncryptionConfiguration(
kms_key_name=self._session._bq_kms_key_name
)
)
if replace:
job_config.write_disposition = "WRITE_TRUNCATE"

@@ -236,7 +240,7 @@ def register(self, vertex_ai_model_id: Optional[str] = None) -> BqmlModel:
options={"vertex_ai_model_id": vertex_ai_model_id}
)
# Register the model and wait for it to finish
self._session._start_query(sql)
self._session._start_query_bqml(sql)

self._model = self._session.bqclient.get_model(self.model_name)
return self
@@ -255,7 +259,7 @@ def _create_model_ref(

def _create_model_with_sql(self, session: bigframes.Session, sql: str) -> BqmlModel:
# fit the model, synchronously
_, job = session._start_query(sql)
_, job = session._start_query_bqml(sql)

# real model path in the session specific hidden dataset and table prefix
model_name_full = f"{job.destination.project}.{job.destination.dataset_id}.{job.destination.table_id}"
@@ -298,6 +302,9 @@ def create_model(
options.update({"INPUT_LABEL_COLS": y_train.columns.tolist()})

session = X_train._session
if session._bq_kms_key_name:
options.update({"kms_key_name": session._bq_kms_key_name})

model_ref = self._create_model_ref(session._anonymous_dataset)

sql = self._model_creation_sql_generator.create_model(
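
In the state shown here (changes from 11 commits), BQML statements go through a separate bqmlclient and the key travels in the CREATE MODEL OPTIONS clause, per commit 2443d40 ("separate bqml client to send kms_key_name via OPTIONS instead of job"). A simplified sketch of that idea follows; it is not the library's actual SQL generator, and the model, dataset, and key names are placeholders.

```python
# Simplified sketch (placeholder names, not bigframes' real SQL generator):
# the customer-managed key travels inside the CREATE MODEL OPTIONS clause.
from typing import Optional


def build_create_model_sql(model_name: str, kms_key_name: Optional[str]) -> str:
    """Assemble a CREATE MODEL statement, adding kms_key_name to OPTIONS when set."""
    options: dict = {"model_type": "linear_reg", "input_label_cols": ["label"]}
    if kms_key_name:
        options["kms_key_name"] = kms_key_name
    rendered = ", ".join(f"{name}={value!r}" for name, value in options.items())
    return (
        f"CREATE OR REPLACE MODEL `{model_name}`\n"
        f"OPTIONS({rendered}) AS\n"
        "SELECT * FROM `my-project.my_dataset.training_data`"
    )


print(
    build_create_model_sql(
        "my-project.my_dataset.my_model",
        "projects/PROJECT_ID/locations/LOCATION/keyRings/KEYRING/cryptoKeys/KEY",
    )
)
```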
1 change: 1 addition & 0 deletions bigframes/pandas/__init__.py
@@ -383,6 +383,7 @@ def _set_default_session_location_if_possible(query):
use_regional_endpoints=options.bigquery.use_regional_endpoints,
credentials=options.bigquery.credentials,
application_name=options.bigquery.application_name,
bq_kms_key_name=options.bigquery.kms_key_name,
)

bqclient = clients_provider.bqclient
57 changes: 50 additions & 7 deletions bigframes/session/__init__.py
@@ -161,6 +161,8 @@ def __init__(
else:
self._location = context.location

self._bq_kms_key_name = context.kms_key_name

# Instantiate a clients provider to help with cloud clients that will be
# used in the future operations in the session
if clients_provider:
@@ -172,9 +174,17 @@ def __init__(
use_regional_endpoints=context.use_regional_endpoints,
credentials=context.credentials,
application_name=context.application_name,
bq_kms_key_name=self._bq_kms_key_name,
)

self._create_bq_datasets()

# TODO(shobs): Remove this logic after https://github.com/ibis-project/ibis/issues/8494
# has been fixed. The ibis client changes the default query job config,
# so we remember the current config here and restore it after the ibis
# client has been created.
original_default_query_job_config = self.bqclient.default_query_job_config

self.ibis_client = typing.cast(
ibis_bigquery.Backend,
ibis.bigquery.connect(
@@ -184,6 +194,9 @@ def __init__(
),
)

self.bqclient.default_query_job_config = original_default_query_job_config

# Resolve the BQ connection for remote function and Vertex AI integration
self._bq_connection = context.bq_connection or _BIGFRAMES_DEFAULT_CONNECTION_ID

# Now that we're starting the session, don't allow the options to be
@@ -195,6 +208,10 @@ def __init__(
def bqclient(self):
return self._clients_provider.bqclient

@property
def bqmlclient(self):
return self._clients_provider.bqmlclient

@property
def bqconnectionclient(self):
return self._clients_provider.bqconnectionclient
@@ -1496,23 +1513,24 @@ def read_gbq_function(
session=self,
)

def _start_query(
self,
@staticmethod
def _start_query_with_client(
bq_client: bigquery.Client,
sql: str,
job_config: Optional[bigquery.job.QueryJobConfig] = None,
max_results: Optional[int] = None,
) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]:
"""
Starts a query job and waits for results.
"""
job_config = self._prepare_job_config(job_config)
job_config = Session._prepare_job_config(job_config)
api_methods = log_adapter.get_and_reset_api_methods()
job_config.labels = bigframes_io.create_job_configs_labels(
job_configs_labels=job_config.labels, api_methods=api_methods
)

try:
query_job = self.bqclient.query(sql, job_config=job_config)
query_job = bq_client.query(sql, job_config=job_config)
except google.api_core.exceptions.Forbidden as ex:
if "Drive credentials" in ex.message:
ex.message += "\nCheck https://cloud.google.com/bigquery/docs/query-drive-data#Google_Drive_permissions."
@@ -1527,6 +1545,32 @@ def _start_query(
results_iterator = query_job.result(max_results=max_results)
return results_iterator, query_job

def _start_query(
self,
sql: str,
job_config: Optional[bigquery.job.QueryJobConfig] = None,
max_results: Optional[int] = None,
) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]:
"""
Starts a BigQuery query job and waits for results.
"""
return Session._start_query_with_client(
self.bqclient, sql, job_config, max_results
)

def _start_query_bqml(
self,
sql: str,
job_config: Optional[bigquery.job.QueryJobConfig] = None,
max_results: Optional[int] = None,
) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]:
"""
Starts a BigQuery ML query job and waits for results.
"""
return Session._start_query_with_client(
self.bqmlclient, sql, job_config, max_results
)

def _cache_with_cluster_cols(
self, array_value: core.ArrayValue, cluster_cols: typing.Sequence[str]
) -> core.ArrayValue:
@@ -1668,11 +1712,10 @@ def _start_generic_job(self, job: formatting_helpers.GenericJob):
else:
job.result()

@staticmethod
def _prepare_job_config(
self, job_config: Optional[bigquery.QueryJobConfig] = None
job_config: Optional[bigquery.QueryJobConfig] = None,
) -> bigquery.QueryJobConfig:
if job_config is None:
job_config = self.bqclient.default_query_job_config
if job_config is None:
job_config = bigquery.QueryJobConfig()
if bigframes.options.compute.maximum_bytes_billed is not None:
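
The static _prepare_job_config above no longer reads the client's default config itself; presumably it can rely on a google-cloud-bigquery behavior (an assumption about the client library, not code from this PR): client.query() merges a per-call QueryJobConfig with the client's default_query_job_config, so per-call settings such as maximum_bytes_billed coexist with a CMEK default set on the client. A standalone sketch:

```python
# Standalone sketch of the assumed merge behavior (placeholder key path):
# the client default carries the CMEK setting, the per-call config carries a
# byte limit, and the resulting job gets both.
from google.cloud import bigquery

client = bigquery.Client()
client.default_query_job_config = bigquery.QueryJobConfig(
    destination_encryption_configuration=bigquery.EncryptionConfiguration(
        kms_key_name="projects/PROJECT_ID/locations/LOCATION/keyRings/KEYRING/cryptoKeys/KEY"
    )
)

per_call_config = bigquery.QueryJobConfig(maximum_bytes_billed=10**10)
query_job = client.query("SELECT 1 AS x", job_config=per_call_config)
print(list(query_job.result()))
```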
63 changes: 45 additions & 18 deletions bigframes/session/clients.py
@@ -68,6 +68,7 @@ def __init__(
use_regional_endpoints: Optional[bool],
credentials: Optional[google.auth.credentials.Credentials],
application_name: Optional[str],
bq_kms_key_name: Optional[str],
):
credentials_project = None
if credentials is None:
@@ -98,39 +99,65 @@ def __init__(
self._location = location
self._use_regional_endpoints = use_regional_endpoints
self._credentials = credentials
self._bq_kms_key_name = bq_kms_key_name

# cloud clients initialized for lazy load
self._bqclient = None
self._bqmlclient = None
self._bqconnectionclient = None
self._bqstoragereadclient = None
self._cloudfunctionsclient = None
self._resourcemanagerclient = None

def _create_bigquery_client(self):
bq_options = None
if self._use_regional_endpoints:
bq_options = google.api_core.client_options.ClientOptions(
api_endpoint=(
_BIGQUERY_REGIONAL_ENDPOINT
if self._location.lower() in _REP_SUPPORTED_REGIONS
else _BIGQUERY_LOCATIONAL_ENDPOINT
).format(location=self._location),
)
bq_info = google.api_core.client_info.ClientInfo(
user_agent=self._application_name
)

bq_client = bigquery.Client(
client_info=bq_info,
client_options=bq_options,
credentials=self._credentials,
project=self._project,
location=self._location,
)

return bq_client

@property
def bqclient(self):
if not self._bqclient:
bq_options = None
if self._use_regional_endpoints:
bq_options = google.api_core.client_options.ClientOptions(
api_endpoint=(
_BIGQUERY_REGIONAL_ENDPOINT
if self._location.lower() in _REP_SUPPORTED_REGIONS
else _BIGQUERY_LOCATIONAL_ENDPOINT
).format(location=self._location),
self._bqclient = self._create_bigquery_client()
if self._bq_kms_key_name:
self._bqclient.default_query_job_config = bigquery.QueryJobConfig(
destination_encryption_configuration=bigquery.EncryptionConfiguration(
kms_key_name=self._bq_kms_key_name
)
)
self._bqclient.default_load_job_config = bigquery.LoadJobConfig(
destination_encryption_configuration=bigquery.EncryptionConfiguration(
kms_key_name=self._bq_kms_key_name
)
)
bq_info = google.api_core.client_info.ClientInfo(
user_agent=self._application_name
)
self._bqclient = bigquery.Client(
client_info=bq_info,
client_options=bq_options,
credentials=self._credentials,
project=self._project,
location=self._location,
)

return self._bqclient

@property
def bqmlclient(self):
if not self._bqmlclient:
self._bqmlclient = self._create_bigquery_client()

return self._bqmlclient

@property
def bqconnectionclient(self):
if not self._bqconnectionclient:
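
To make the effect of the two client-level defaults above concrete, here is a small standalone sketch (assumed google-cloud-bigquery behavior; project, dataset, and key path are placeholders): with default_load_job_config set, load jobs issued without an explicit per-call config land in CMEK-protected tables.

```python
# Sketch: a client-level default load config makes ad-hoc loads CMEK-protected
# without threading a job config through every call site. Identifiers below
# are placeholders.
import io

from google.cloud import bigquery

key = "projects/PROJECT_ID/locations/LOCATION/keyRings/KEYRING/cryptoKeys/KEY"

client = bigquery.Client()
client.default_load_job_config = bigquery.LoadJobConfig(
    source_format=bigquery.SourceFormat.CSV,
    skip_leading_rows=1,
    autodetect=True,
    destination_encryption_configuration=bigquery.EncryptionConfiguration(
        kms_key_name=key
    ),
)

csv_bytes = io.BytesIO(b"id,name\n1,alpha\n2,beta\n")
load_job = client.load_table_from_file(
    csv_bytes, "my-project.my_dataset.cmek_table"
)
load_job.result()

table = client.get_table("my-project.my_dataset.cmek_table")
print(table.encryption_configuration.kms_key_name)  # the key configured above
```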