feat: Support CMEK for BQ tables (#403)
* feat: Support CMEK for BQ tables

* add more tests

* add unit tests

* add more tests, fix broken tests

* separate bqml client to send kms_key_name via OPTIONS instead of job config

* fix unit tests

* fix mypy

* skip cmek test for empty cmek

* move staticmethods to helper module

* revert bqmlclient, pass cmek through call time job config

* revert bqmlclient unit test

* fix mypy failure

* use better named key, disable use_query_cache in test

* rename bqml create model internal method

* fix renamed method's references in unit tests

* remove stray bqmlclient variable
shobsi authored Mar 9, 2024
1 parent 815f578 commit 9a678e3
Showing 10 changed files with 450 additions and 67 deletions.
25 changes: 25 additions & 0 deletions bigframes/_config/bigquery_options.py
@@ -39,13 +39,15 @@ def __init__(
bq_connection: Optional[str] = None,
use_regional_endpoints: bool = False,
application_name: Optional[str] = None,
kms_key_name: Optional[str] = None,
):
self._credentials = credentials
self._project = project
self._location = location
self._bq_connection = bq_connection
self._use_regional_endpoints = use_regional_endpoints
self._application_name = application_name
self._kms_key_name = kms_key_name
self._session_started = False

@property
@@ -148,3 +150,26 @@ def use_regional_endpoints(self, value: bool):
)

self._use_regional_endpoints = value

@property
def kms_key_name(self) -> Optional[str]:
"""Customer managed encryption key used to control encryption of the
data-at-rest in BigQuery. This is of the format
projects/PROJECT_ID/locations/LOCATION/keyRings/KEYRING/cryptoKeys/KEY
See https://cloud.google.com/bigquery/docs/customer-managed-encryption
for more details.
Please make sure the project used for Bigquery DataFrames has "Cloud KMS
CryptoKey Encrypter/Decrypter" role in the key's project, See
https://cloud.google.com/bigquery/docs/customer-managed-encryption#assign_role
for steps on how to ensure that.
"""
return self._kms_key_name

@kms_key_name.setter
def kms_key_name(self, value: str):
if self._session_started and self._kms_key_name != value:
raise ValueError(SESSION_STARTED_MESSAGE.format(attribute="kms_key_name"))

self._kms_key_name = value
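With the option in place, a minimal usage sketch looks like the following (the project, key ring, and key names are placeholders, not values from this commit):

```python
import bigframes.pandas as bpd

# Placeholder key: substitute a real Cloud KMS key whose project grants
# BigQuery the "Cloud KMS CryptoKey Encrypter/Decrypter" role.
bpd.options.bigquery.kms_key_name = (
    "projects/my-project/locations/us/keyRings/my-ring/cryptoKeys/my-key"
)

# Tables materialized by this session should now be created with the CMEK.
df = bpd.read_gbq("SELECT 1 AS x")
```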
10 changes: 7 additions & 3 deletions bigframes/ml/core.py
@@ -212,7 +212,8 @@ def principal_component_info(self) -> bpd.DataFrame:
return self._session.read_gbq(sql)

def copy(self, new_model_name: str, replace: bool = False) -> BqmlModel:
job_config = bigquery.job.CopyJobConfig()
job_config = self._session._prepare_copy_job_config()

if replace:
job_config.write_disposition = "WRITE_TRUNCATE"

Expand All @@ -236,7 +237,7 @@ def register(self, vertex_ai_model_id: Optional[str] = None) -> BqmlModel:
options={"vertex_ai_model_id": vertex_ai_model_id}
)
# Register the model and wait for it to finish
self._session._start_query(sql)
self._session._start_query_create_model(sql)

self._model = self._session.bqclient.get_model(self.model_name)
return self
@@ -255,7 +256,7 @@ def _create_model_ref(

def _create_model_with_sql(self, session: bigframes.Session, sql: str) -> BqmlModel:
# fit the model, synchronously
_, job = session._start_query(sql)
_, job = session._start_query_create_model(sql)

# real model path in the session specific hidden dataset and table prefix
model_name_full = f"{job.destination.project}.{job.destination.dataset_id}.{job.destination.table_id}"
@@ -298,6 +299,9 @@ def create_model(
options.update({"INPUT_LABEL_COLS": y_train.columns.tolist()})

session = X_train._session
if session._bq_kms_key_name:
options.update({"kms_key_name": session._bq_kms_key_name})

model_ref = self._create_model_ref(session._anonymous_dataset)

sql = self._model_creation_sql_generator.create_model(
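Note that BQML takes the key through the `CREATE MODEL` `OPTIONS` clause rather than through the job config (see `_start_query_create_model` in `bigframes/session/__init__.py` below). A rough sketch of the statement shape this produces (identifiers, option names, and formatting are illustrative, not the exact generator output):

```python
# Illustrative reconstruction of the SQL that create_model asks the
# generator to render once kms_key_name has been merged into `options`.
options = {
    "model_type": "linear_reg",
    "kms_key_name": "projects/my-project/locations/us/keyRings/my-ring/cryptoKeys/my-key",
}
option_sql = ",\n  ".join(f"{key}={value!r}" for key, value in options.items())
print(
    "CREATE OR REPLACE MODEL `my-project.my_dataset.my_model`\n"
    f"OPTIONS(\n  {option_sql}\n)\n"
    "AS SELECT label, feature1 FROM `my-project.my_dataset.training_data`"
)
```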
1 change: 1 addition & 0 deletions bigframes/pandas/__init__.py
@@ -383,6 +383,7 @@ def _set_default_session_location_if_possible(query):
use_regional_endpoints=options.bigquery.use_regional_endpoints,
credentials=options.bigquery.credentials,
application_name=options.bigquery.application_name,
bq_kms_key_name=options.bigquery.kms_key_name,
)

bqclient = clients_provider.bqclient
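Because `_set_default_session_location_if_possible` runs before any session exists, the key must be configured up front; once a session has started, the `kms_key_name` setter above rejects changes. A sketch of that behavior, with placeholder key names:

```python
import bigframes.pandas as bpd

bpd.options.bigquery.kms_key_name = (
    "projects/my-project/locations/us/keyRings/my-ring/cryptoKeys/key-a"
)
_ = bpd.read_gbq("SELECT 1 AS x")  # first operation starts the session

try:
    # Swapping keys mid-session is not allowed and should raise ValueError.
    bpd.options.bigquery.kms_key_name = (
        "projects/my-project/locations/us/keyRings/my-ring/cryptoKeys/key-b"
    )
except ValueError as exc:
    print(exc)
```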
126 changes: 87 additions & 39 deletions bigframes/session/__init__.py
@@ -65,7 +65,6 @@

import bigframes._config.bigquery_options as bigquery_options
import bigframes.constants as constants
from bigframes.core import log_adapter
import bigframes.core as core
import bigframes.core.blocks as blocks
import bigframes.core.compile
@@ -84,7 +83,6 @@

# Even though the ibis.backends.bigquery import is unused, it's needed
# to register new and replacement ops with the Ibis BigQuery backend.
import third_party.bigframes_vendored.ibis.backends.bigquery # noqa
import third_party.bigframes_vendored.ibis.expr.operations as vendored_ibis_ops
import third_party.bigframes_vendored.pandas.io.gbq as third_party_pandas_gbq
import third_party.bigframes_vendored.pandas.io.parquet as third_party_pandas_parquet
@@ -161,6 +159,8 @@ def __init__(
else:
self._location = context.location

self._bq_kms_key_name = context.kms_key_name

# Instantiate a clients provider to help with cloud clients that will be
# used in the future operations in the session
if clients_provider:
@@ -172,9 +172,17 @@
use_regional_endpoints=context.use_regional_endpoints,
credentials=context.credentials,
application_name=context.application_name,
bq_kms_key_name=self._bq_kms_key_name,
)

self._create_bq_datasets()

# TODO(shobs): Remove this logic after https://github.com/ibis-project/ibis/issues/8494
# has been fixed. The ibis client changes the default query job config
# so we are going to remember the current config and restore it after
# the ibis client has been created
original_default_query_job_config = self.bqclient.default_query_job_config

self.ibis_client = typing.cast(
ibis_bigquery.Backend,
ibis.bigquery.connect(
Expand All @@ -184,6 +192,9 @@ def __init__(
),
)

self.bqclient.default_query_job_config = original_default_query_job_config

# Resolve the BQ connection for remote function and Vertex AI integration
self._bq_connection = context.bq_connection or _BIGFRAMES_DEFAULT_CONNECTION_ID

# Now that we're starting the session, don't allow the options to be
@@ -929,19 +940,21 @@ def _read_pandas_load_job(
pandas_dataframe_copy.columns = pandas.Index(new_col_ids)
pandas_dataframe_copy[ordering_col] = np.arange(pandas_dataframe_copy.shape[0])

job_config = self._prepare_load_job_config()

# Specify the datetime dtypes, which are otherwise auto-detected as timestamp types.
schema: list[bigquery.SchemaField] = []
for column, dtype in zip(pandas_dataframe.columns, pandas_dataframe.dtypes):
if dtype == "timestamp[us][pyarrow]":
schema.append(
bigquery.SchemaField(column, bigquery.enums.SqlTypeNames.DATETIME)
)
job_config.schema = schema

# Clustering probably not needed anyway as pandas tables are small
cluster_cols = [ordering_col]

job_config = bigquery.LoadJobConfig(schema=schema)
job_config.clustering_fields = cluster_cols

job_config.labels = {"bigframes-api": api_name}

load_table_destination = bigframes_io.random_table(self._anonymous_dataset)
@@ -1061,7 +1074,7 @@ def read_csv(
f"{constants.FEEDBACK_LINK}"
)

job_config = bigquery.LoadJobConfig()
job_config = self._prepare_load_job_config()
job_config.create_disposition = bigquery.CreateDisposition.CREATE_IF_NEEDED
job_config.source_format = bigquery.SourceFormat.CSV
job_config.write_disposition = bigquery.WriteDisposition.WRITE_EMPTY
@@ -1136,7 +1149,7 @@ def read_parquet(
table = bigframes_io.random_table(self._anonymous_dataset)

if engine == "bigquery":
job_config = bigquery.LoadJobConfig()
job_config = self._prepare_load_job_config()
job_config.create_disposition = bigquery.CreateDisposition.CREATE_IF_NEEDED
job_config.source_format = bigquery.SourceFormat.PARQUET
job_config.write_disposition = bigquery.WriteDisposition.WRITE_EMPTY
@@ -1194,7 +1207,7 @@ def read_json(
"'lines' keyword is only valid when 'orient' is 'records'."
)

job_config = bigquery.LoadJobConfig()
job_config = self._prepare_load_job_config()
job_config.create_disposition = bigquery.CreateDisposition.CREATE_IF_NEEDED
job_config.source_format = bigquery.SourceFormat.NEWLINE_DELIMITED_JSON
job_config.write_disposition = bigquery.WriteDisposition.WRITE_EMPTY
@@ -1518,36 +1531,84 @@ def read_gbq_function(
session=self,
)

def _prepare_query_job_config(
self,
job_config: Optional[bigquery.QueryJobConfig] = None,
) -> bigquery.QueryJobConfig:
if job_config is None:
job_config = bigquery.QueryJobConfig()
else:
# Create a copy so that we don't mutate the original config passed
job_config = typing.cast(
bigquery.QueryJobConfig,
bigquery.QueryJobConfig.from_api_repr(job_config.to_api_repr()),
)

if bigframes.options.compute.maximum_bytes_billed is not None:
job_config.maximum_bytes_billed = (
bigframes.options.compute.maximum_bytes_billed
)

if self._bq_kms_key_name:
job_config.destination_encryption_configuration = (
bigquery.EncryptionConfiguration(kms_key_name=self._bq_kms_key_name)
)

return job_config

def _prepare_load_job_config(self) -> bigquery.LoadJobConfig:
# Create a fresh config, since callers do not pass one in
job_config = bigquery.LoadJobConfig()

if self._bq_kms_key_name:
job_config.destination_encryption_configuration = (
bigquery.EncryptionConfiguration(kms_key_name=self._bq_kms_key_name)
)

return job_config

def _prepare_copy_job_config(self) -> bigquery.CopyJobConfig:
# Create a fresh config, since callers do not pass one in
job_config = bigquery.CopyJobConfig()

if self._bq_kms_key_name:
job_config.destination_encryption_configuration = (
bigquery.EncryptionConfiguration(kms_key_name=self._bq_kms_key_name)
)

return job_config

def _start_query(
self,
sql: str,
job_config: Optional[bigquery.job.QueryJobConfig] = None,
max_results: Optional[int] = None,
) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]:
"""
Starts query job and waits for results.
Starts BigQuery query job and waits for results.
"""
job_config = self._prepare_job_config(job_config)
api_methods = log_adapter.get_and_reset_api_methods()
job_config.labels = bigframes_io.create_job_configs_labels(
job_configs_labels=job_config.labels, api_methods=api_methods
job_config = self._prepare_query_job_config(job_config)
return bigframes.session._io.bigquery.start_query_with_client(
self.bqclient, sql, job_config, max_results
)

try:
query_job = self.bqclient.query(sql, job_config=job_config)
except google.api_core.exceptions.Forbidden as ex:
if "Drive credentials" in ex.message:
ex.message += "\nCheck https://cloud.google.com/bigquery/docs/query-drive-data#Google_Drive_permissions."
raise
def _start_query_create_model(
self,
sql: str,
) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]:
"""
Starts BigQuery ML CREATE MODEL query job and waits for results.
"""
job_config = self._prepare_query_job_config()

opts = bigframes.options.display
if opts.progress_bar is not None and not query_job.configuration.dry_run:
results_iterator = formatting_helpers.wait_for_query_job(
query_job, max_results, opts.progress_bar
)
else:
results_iterator = query_job.result(max_results=max_results)
return results_iterator, query_job
# BQML expects kms_key_name through OPTIONS and not through job config,
# so we must reset any encryption set in the job config
# https://cloud.google.com/bigquery/docs/customer-managed-encryption#encrypt-model
job_config.destination_encryption_configuration = None

return bigframes.session._io.bigquery.start_query_with_client(
self.bqclient, sql, job_config
)

def _cache_with_cluster_cols(
self, array_value: core.ArrayValue, cluster_cols: typing.Sequence[str]
Expand Down Expand Up @@ -1696,19 +1757,6 @@ def _start_generic_job(self, job: formatting_helpers.GenericJob):
else:
job.result()

def _prepare_job_config(
self, job_config: Optional[bigquery.QueryJobConfig] = None
) -> bigquery.QueryJobConfig:
if job_config is None:
job_config = self.bqclient.default_query_job_config
if job_config is None:
job_config = bigquery.QueryJobConfig()
if bigframes.options.compute.maximum_bytes_billed is not None:
job_config.maximum_bytes_billed = (
bigframes.options.compute.maximum_bytes_billed
)
return job_config


def connect(context: Optional[bigquery_options.BigQueryOptions] = None) -> Session:
return Session(context)
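A self-contained sketch of the copy-then-augment pattern `_prepare_query_job_config` uses, assuming only `google-cloud-bigquery` (`with_cmek` and the key path are illustrative, not part of this commit): round-tripping through the API representation yields an independent copy, so the caller's config is never mutated.

```python
import google.cloud.bigquery as bigquery


def with_cmek(job_config: bigquery.QueryJobConfig, kms_key_name: str):
    # Round-trip through the API representation to get an independent copy,
    # mirroring _prepare_query_job_config above.
    config = bigquery.QueryJobConfig.from_api_repr(job_config.to_api_repr())
    config.destination_encryption_configuration = bigquery.EncryptionConfiguration(
        kms_key_name=kms_key_name
    )
    return config


original = bigquery.QueryJobConfig(use_query_cache=False)
prepared = with_cmek(original, "projects/p/locations/l/keyRings/r/cryptoKeys/k")

# The caller's config is untouched; only the copy carries the CMEK.
assert original.destination_encryption_configuration is None
assert prepared.destination_encryption_configuration.kms_key_name.endswith("/k")
```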
39 changes: 38 additions & 1 deletion bigframes/session/_io/bigquery.py
@@ -20,11 +20,17 @@
import itertools
import textwrap
import types
from typing import Dict, Iterable, Optional, Sequence, Union
from typing import Dict, Iterable, Optional, Sequence, Tuple, Union
import uuid

import google.api_core.exceptions
import google.cloud.bigquery as bigquery

import bigframes
from bigframes.core import log_adapter
import bigframes.formatting_helpers as formatting_helpers
import bigframes.session._io.bigquery as bigframes_io

IO_ORDERING_ID = "bqdf_row_nums"
MAX_LABELS_COUNT = 64
TEMP_TABLE_PREFIX = "bqdf{date}_{random_id}"
@@ -207,3 +213,34 @@ def format_option(key: str, value: Union[bool, str]) -> str:
if isinstance(value, bool):
return f"{key}=true" if value else f"{key}=false"
return f"{key}={repr(value)}"


def start_query_with_client(
bq_client: bigquery.Client,
sql: str,
job_config: bigquery.job.QueryJobConfig,
max_results: Optional[int] = None,
) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]:
"""
Starts query job and waits for results.
"""
api_methods = log_adapter.get_and_reset_api_methods()
job_config.labels = bigframes_io.create_job_configs_labels(
job_configs_labels=job_config.labels, api_methods=api_methods
)

try:
query_job = bq_client.query(sql, job_config=job_config)
except google.api_core.exceptions.Forbidden as ex:
if "Drive credentials" in ex.message:
ex.message += "\nCheck https://cloud.google.com/bigquery/docs/query-drive-data#Google_Drive_permissions."
raise

opts = bigframes.options.display
if opts.progress_bar is not None and not query_job.configuration.dry_run:
results_iterator = formatting_helpers.wait_for_query_job(
query_job, max_results, opts.progress_bar
)
else:
results_iterator = query_job.result(max_results=max_results)
return results_iterator, query_job
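A hypothetical caller of the relocated helper, assuming application-default credentials and a default project (the query and config are placeholders). The helper attaches the bigframes API labels, appends the Drive-permission hint on `Forbidden` errors, and returns the row iterator together with the job:

```python
import google.cloud.bigquery as bigquery

from bigframes.session._io.bigquery import start_query_with_client

client = bigquery.Client()  # uses application-default credentials
job_config = bigquery.QueryJobConfig(use_query_cache=False)

rows, job = start_query_with_client(
    client, "SELECT 1 AS x", job_config, max_results=10
)
print(list(rows), job.job_id)
```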