feat: Support CMEK for BQ tables #403

Merged
merged 25 commits into main from shobs-cmek on Mar 9, 2024
Changes from 19 commits
Commits
25 commits
144190e
feat: Support CMEK for BQ tables
shobsi Feb 29, 2024
eda994e
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek
shobsi Feb 29, 2024
6b41ec0
add more tests
shobsi Mar 1, 2024
d2f3f3d
add unit tests
shobsi Mar 1, 2024
425cf12
add more tests, fix broken tests
shobsi Mar 1, 2024
2443d40
separate bqml client to send kms_key_name via OPTIONS instead of job
shobsi Mar 4, 2024
ca50e6c
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek
shobsi Mar 4, 2024
07d97f4
fix unit tests
shobsi Mar 4, 2024
950dd27
fix mypy
shobsi Mar 4, 2024
ee717fe
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek
shobsi Mar 4, 2024
927d8a2
skip cmek test for empty cmek
shobsi Mar 4, 2024
827bef2
move staticmethods to helper module
shobsi Mar 5, 2024
e1b3258
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek
shobsi Mar 6, 2024
f1c8b00
revert bqmlclient, pass cmek through call time job config
shobsi Mar 7, 2024
28064f8
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek
shobsi Mar 7, 2024
073f989
revert bqmlclient unit test
shobsi Mar 7, 2024
79a4b73
fix mypy failure
shobsi Mar 7, 2024
313bb12
use better named key, disable use_query_cache in test
shobsi Mar 7, 2024
9c8d064
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek
shobsi Mar 7, 2024
4d4bb10
rename bqml create model internal method
shobsi Mar 8, 2024
0eff3a6
fix renamed method's reference in unit tests
shobsi Mar 8, 2024
185edd9
remove stray bqmlclient variable
shobsi Mar 8, 2024
07dca56
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek
shobsi Mar 8, 2024
86ebba7
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek
shobsi Mar 8, 2024
2ae5361
Merge remote-tracking branch 'refs/remotes/github/main' into shobs-cmek
shobsi Mar 8, 2024
25 changes: 25 additions & 0 deletions bigframes/_config/bigquery_options.py
@@ -39,13 +39,15 @@ def __init__(
bq_connection: Optional[str] = None,
use_regional_endpoints: bool = False,
application_name: Optional[str] = None,
kms_key_name: Optional[str] = None,
):
self._credentials = credentials
self._project = project
self._location = location
self._bq_connection = bq_connection
self._use_regional_endpoints = use_regional_endpoints
self._application_name = application_name
self._kms_key_name = kms_key_name
self._session_started = False

@property
@@ -148,3 +150,26 @@ def use_regional_endpoints(self, value: bool):
)

self._use_regional_endpoints = value

@property
def kms_key_name(self) -> Optional[str]:
"""Customer managed encryption key used to control encryption of the
data-at-rest in BigQuery. This is of the format
projects/PROJECT_ID/locations/LOCATION/keyRings/KEYRING/cryptoKeys/KEY

See https://cloud.google.com/bigquery/docs/customer-managed-encryption
for more details.

Please make sure the project used for Bigquery DataFrames has "Cloud KMS
CryptoKey Encrypter/Decrypter" role in the key's project, See
https://cloud.google.com/bigquery/docs/customer-managed-encryption#assign_role
for steps on how to ensure that.
"""
return self._kms_key_name

@kms_key_name.setter
def kms_key_name(self, value: str):
if self._session_started and self._kms_key_name != value:
raise ValueError(SESSION_STARTED_MESSAGE.format(attribute="kms_key_name"))

self._kms_key_name = value
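
Taken together, the new option is meant to be set before any session work begins. A minimal usage sketch, assuming a pre-existing key on which the BigQuery service agent has the Encrypter/Decrypter role (the project, key ring, and key names below are hypothetical):

```python
import bigframes.pandas as bpd

# Must be set before the session starts; once a session has started,
# changing it raises ValueError (see the setter above).
bpd.options.bigquery.kms_key_name = (
    "projects/my-project/locations/us/keyRings/my-ring/cryptoKeys/my-key"
)

# Tables this session materializes are now encrypted with the CMEK.
df = bpd.read_gbq("bigquery-public-data.ml_datasets.penguins")
```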
10 changes: 7 additions & 3 deletions bigframes/ml/core.py
@@ -212,7 +212,8 @@ def principal_component_info(self) -> bpd.DataFrame:
return self._session.read_gbq(sql)

def copy(self, new_model_name: str, replace: bool = False) -> BqmlModel:
job_config = bigquery.job.CopyJobConfig()
job_config = self._session._prepare_copy_job_config()

if replace:
job_config.write_disposition = "WRITE_TRUNCATE"

@@ -236,7 +237,7 @@ def register(self, vertex_ai_model_id: Optional[str] = None) -> BqmlModel:
options={"vertex_ai_model_id": vertex_ai_model_id}
)
# Register the model and wait for it to finish
self._session._start_query(sql)
self._session._start_query_bqml(sql)

self._model = self._session.bqclient.get_model(self.model_name)
return self
@@ -255,7 +256,7 @@ def _create_model_ref(

def _create_model_with_sql(self, session: bigframes.Session, sql: str) -> BqmlModel:
# fit the model, synchronously
_, job = session._start_query(sql)
_, job = session._start_query_bqml(sql)

# real model path in the session specific hidden dataset and table prefix
model_name_full = f"{job.destination.project}.{job.destination.dataset_id}.{job.destination.table_id}"
@@ -298,6 +299,9 @@ def create_model(
options.update({"INPUT_LABEL_COLS": y_train.columns.tolist()})

session = X_train._session
if session._bq_kms_key_name:
options.update({"kms_key_name": session._bq_kms_key_name})

model_ref = self._create_model_ref(session._anonymous_dataset)

sql = self._model_creation_sql_generator.create_model(
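With the change above, `create_model` injects the session's CMEK into the model OPTIONS, so the generated DDL looks roughly like this — a hypothetical sketch with made-up names, not the exact output of the SQL generator:

```python
# Sketch of the CREATE MODEL statement produced when a CMEK is configured:
# the key travels in the OPTIONS clause rather than in the query job config.
ddl = """
CREATE OR REPLACE MODEL `my-project._anonymous_dataset.my_model`
OPTIONS (
    model_type = 'linear_reg',
    input_label_cols = ['label'],
    kms_key_name = 'projects/my-project/locations/us/keyRings/my-ring/cryptoKeys/my-key'
) AS
SELECT * FROM `my-project.my_dataset.training_data`
"""
```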
1 change: 1 addition & 0 deletions bigframes/pandas/__init__.py
@@ -383,6 +383,7 @@ def _set_default_session_location_if_possible(query):
use_regional_endpoints=options.bigquery.use_regional_endpoints,
credentials=options.bigquery.credentials,
application_name=options.bigquery.application_name,
bq_kms_key_name=options.bigquery.kms_key_name,
)

bqclient = clients_provider.bqclient
125 changes: 86 additions & 39 deletions bigframes/session/__init__.py
@@ -65,7 +65,6 @@

import bigframes._config.bigquery_options as bigquery_options
import bigframes.constants as constants
from bigframes.core import log_adapter
import bigframes.core as core
import bigframes.core.blocks as blocks
import bigframes.core.compile
@@ -84,7 +83,6 @@

# Even though the ibis.backends.bigquery import is unused, it's needed
# to register new and replacement ops with the Ibis BigQuery backend.
import third_party.bigframes_vendored.ibis.backends.bigquery # noqa
import third_party.bigframes_vendored.ibis.expr.operations as vendored_ibis_ops
import third_party.bigframes_vendored.pandas.io.gbq as third_party_pandas_gbq
import third_party.bigframes_vendored.pandas.io.parquet as third_party_pandas_parquet
@@ -161,6 +159,8 @@ def __init__(
else:
self._location = context.location

self._bq_kms_key_name = context.kms_key_name

# Instantiate a clients provider to help with cloud clients that will be
# used in the future operations in the session
if clients_provider:
@@ -172,9 +172,17 @@
use_regional_endpoints=context.use_regional_endpoints,
credentials=context.credentials,
application_name=context.application_name,
bq_kms_key_name=self._bq_kms_key_name,
)

self._create_bq_datasets()

# TODO(shobs): Remove this logic after https://github.com/ibis-project/ibis/issues/8494
# has been fixed. The ibis client changes the default query job config
# so we are going to remember the current config and restore it after
# the ibis client has been created
original_default_query_job_config = self.bqclient.default_query_job_config

self.ibis_client = typing.cast(
ibis_bigquery.Backend,
ibis.bigquery.connect(
@@ -184,6 +192,9 @@
),
)

self.bqclient.default_query_job_config = original_default_query_job_config

# Resolve the BQ connection for remote function and Vertex AI integration
self._bq_connection = context.bq_connection or _BIGFRAMES_DEFAULT_CONNECTION_ID

# Now that we're starting the session, don't allow the options to be
@@ -929,19 +940,21 @@ def _read_pandas_load_job(
pandas_dataframe_copy.columns = pandas.Index(new_col_ids)
pandas_dataframe_copy[ordering_col] = np.arange(pandas_dataframe_copy.shape[0])

job_config = self._prepare_load_job_config()

# Specify the datetime dtypes, which is auto-detected as timestamp types.
schema: list[bigquery.SchemaField] = []
for column, dtype in zip(pandas_dataframe.columns, pandas_dataframe.dtypes):
if dtype == "timestamp[us][pyarrow]":
schema.append(
bigquery.SchemaField(column, bigquery.enums.SqlTypeNames.DATETIME)
)
job_config.schema = schema

# Clustering probably not needed anyways as pandas tables are small
cluster_cols = [ordering_col]

job_config = bigquery.LoadJobConfig(schema=schema)
job_config.clustering_fields = cluster_cols

job_config.labels = {"bigframes-api": api_name}

load_table_destination = bigframes_io.random_table(self._anonymous_dataset)
@@ -1061,7 +1074,7 @@ def read_csv(
f"{constants.FEEDBACK_LINK}"
)

job_config = bigquery.LoadJobConfig()
job_config = self._prepare_load_job_config()
job_config.create_disposition = bigquery.CreateDisposition.CREATE_IF_NEEDED
job_config.source_format = bigquery.SourceFormat.CSV
job_config.write_disposition = bigquery.WriteDisposition.WRITE_EMPTY
@@ -1136,7 +1149,7 @@ def read_parquet(
table = bigframes_io.random_table(self._anonymous_dataset)

if engine == "bigquery":
job_config = bigquery.LoadJobConfig()
job_config = self._prepare_load_job_config()
job_config.create_disposition = bigquery.CreateDisposition.CREATE_IF_NEEDED
job_config.source_format = bigquery.SourceFormat.PARQUET
job_config.write_disposition = bigquery.WriteDisposition.WRITE_EMPTY
@@ -1194,7 +1207,7 @@ def read_json(
"'lines' keyword is only valid when 'orient' is 'records'."
)

job_config = bigquery.LoadJobConfig()
job_config = self._prepare_load_job_config()
job_config.create_disposition = bigquery.CreateDisposition.CREATE_IF_NEEDED
job_config.source_format = bigquery.SourceFormat.NEWLINE_DELIMITED_JSON
job_config.write_disposition = bigquery.WriteDisposition.WRITE_EMPTY
@@ -1518,36 +1531,83 @@ def read_gbq_function(
session=self,
)

def _prepare_query_job_config(
self,
job_config: Optional[bigquery.QueryJobConfig] = None,
) -> bigquery.QueryJobConfig:
if job_config is None:
job_config = bigquery.QueryJobConfig()
else:
# Create a copy so that we don't mutate the original config passed
job_config = typing.cast(
bigquery.QueryJobConfig,
bigquery.QueryJobConfig.from_api_repr(job_config.to_api_repr()),
)

if bigframes.options.compute.maximum_bytes_billed is not None:
job_config.maximum_bytes_billed = (
bigframes.options.compute.maximum_bytes_billed
)

if self._bq_kms_key_name:
job_config.destination_encryption_configuration = (
bigquery.EncryptionConfiguration(kms_key_name=self._bq_kms_key_name)
)

return job_config

def _prepare_load_job_config(self) -> bigquery.LoadJobConfig:
# Create a copy so that we don't mutate the original config passed
job_config = bigquery.LoadJobConfig()

if self._bq_kms_key_name:
job_config.destination_encryption_configuration = (
bigquery.EncryptionConfiguration(kms_key_name=self._bq_kms_key_name)
)

return job_config

def _prepare_copy_job_config(self) -> bigquery.CopyJobConfig:
# Create a copy so that we don't mutate the original config passed
job_config = bigquery.CopyJobConfig()

if self._bq_kms_key_name:
job_config.destination_encryption_configuration = (
bigquery.EncryptionConfiguration(kms_key_name=self._bq_kms_key_name)
)

return job_config

def _start_query(
self,
sql: str,
job_config: Optional[bigquery.job.QueryJobConfig] = None,
max_results: Optional[int] = None,
) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]:
"""
Starts query job and waits for results.
Starts BigQuery query job and waits for results.
"""
job_config = self._prepare_job_config(job_config)
api_methods = log_adapter.get_and_reset_api_methods()
job_config.labels = bigframes_io.create_job_configs_labels(
job_configs_labels=job_config.labels, api_methods=api_methods
job_config = self._prepare_query_job_config(job_config)
return bigframes.session._io.bigquery.start_query_with_client(
self.bqclient, sql, job_config, max_results
)

try:
query_job = self.bqclient.query(sql, job_config=job_config)
except google.api_core.exceptions.Forbidden as ex:
if "Drive credentials" in ex.message:
ex.message += "\nCheck https://cloud.google.com/bigquery/docs/query-drive-data#Google_Drive_permissions."
raise
def _start_query_bqml(
self,
sql: str,
) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]:
"""
Starts BigQuery ML query job and waits for results.
"""
job_config = self._prepare_query_job_config()

opts = bigframes.options.display
if opts.progress_bar is not None and not query_job.configuration.dry_run:
results_iterator = formatting_helpers.wait_for_query_job(
query_job, max_results, opts.progress_bar
)
else:
results_iterator = query_job.result(max_results=max_results)
return results_iterator, query_job
# BQML expects kms_key_name through OPTIONS and not through job config,
# so we must reset any encryption set in the job config
job_config.destination_encryption_configuration = None

return bigframes.session._io.bigquery.start_query_with_client(
self.bqclient, sql, job_config
)

def _cache_with_cluster_cols(
self, array_value: core.ArrayValue, cluster_cols: typing.Sequence[str]
@@ -1690,19 +1750,6 @@ def _start_generic_job(self, job: formatting_helpers.GenericJob):
else:
job.result()

def _prepare_job_config(
self, job_config: Optional[bigquery.QueryJobConfig] = None
) -> bigquery.QueryJobConfig:
if job_config is None:
job_config = self.bqclient.default_query_job_config
if job_config is None:
job_config = bigquery.QueryJobConfig()
if bigframes.options.compute.maximum_bytes_billed is not None:
job_config.maximum_bytes_billed = (
bigframes.options.compute.maximum_bytes_billed
)
return job_config


def connect(context: Optional[bigquery_options.BigQueryOptions] = None) -> Session:
return Session(context)
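All three `_prepare_*_job_config` helpers above share one pattern: build a fresh (or copied) job config and, when the session has a CMEK, point the destination encryption at that key. A self-contained sketch of that shared pattern — the helper name is illustrative, not part of this PR:

```python
from typing import Optional, TypeVar

import google.cloud.bigquery as bigquery

# Query, load, and copy job configs all expose
# destination_encryption_configuration.
_JobConfig = TypeVar(
    "_JobConfig",
    bigquery.QueryJobConfig,
    bigquery.LoadJobConfig,
    bigquery.CopyJobConfig,
)


def with_destination_cmek(
    job_config: _JobConfig, kms_key_name: Optional[str]
) -> _JobConfig:
    # Any table the job writes (query destination, load target, copy
    # target) gets encrypted with the customer-managed key; None leaves
    # Google-managed default encryption in place.
    if kms_key_name:
        job_config.destination_encryption_configuration = (
            bigquery.EncryptionConfiguration(kms_key_name=kms_key_name)
        )
    return job_config
```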
39 changes: 38 additions & 1 deletion bigframes/session/_io/bigquery.py
@@ -20,11 +20,17 @@
import itertools
import textwrap
import types
from typing import Dict, Iterable, Optional, Sequence, Union
from typing import Dict, Iterable, Optional, Sequence, Tuple, Union
import uuid

import google.api_core.exceptions
import google.cloud.bigquery as bigquery

import bigframes
from bigframes.core import log_adapter
import bigframes.formatting_helpers as formatting_helpers
import bigframes.session._io.bigquery as bigframes_io

IO_ORDERING_ID = "bqdf_row_nums"
MAX_LABELS_COUNT = 64
TEMP_TABLE_PREFIX = "bqdf{date}_{random_id}"
@@ -207,3 +213,34 @@ def format_option(key: str, value: Union[bool, str]) -> str:
if isinstance(value, bool):
return f"{key}=true" if value else f"{key}=false"
return f"{key}={repr(value)}"


def start_query_with_client(
bq_client: bigquery.Client,
sql: str,
job_config: bigquery.job.QueryJobConfig,
max_results: Optional[int] = None,
) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]:
"""
Starts query job and waits for results.
"""
api_methods = log_adapter.get_and_reset_api_methods()
job_config.labels = bigframes_io.create_job_configs_labels(
job_configs_labels=job_config.labels, api_methods=api_methods
)

try:
query_job = bq_client.query(sql, job_config=job_config)
except google.api_core.exceptions.Forbidden as ex:
if "Drive credentials" in ex.message:
ex.message += "\nCheck https://cloud.google.com/bigquery/docs/query-drive-data#Google_Drive_permissions."
raise

opts = bigframes.options.display
if opts.progress_bar is not None and not query_job.configuration.dry_run:
results_iterator = formatting_helpers.wait_for_query_job(
query_job, max_results, opts.progress_bar
)
else:
results_iterator = query_job.result(max_results=max_results)
return results_iterator, query_job
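
For reference, a minimal usage sketch of the extracted helper, assuming application-default credentials and a default project:

```python
import google.cloud.bigquery as bigquery

from bigframes.session._io.bigquery import start_query_with_client

client = bigquery.Client()
rows, job = start_query_with_client(
    client, "SELECT 1 AS x", bigquery.QueryJobConfig()
)
print(job.job_id, [dict(row) for row in rows])
```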