From 9a678e35201d935e1d93875429005033cfe7cff6 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Sat, 9 Mar 2024 02:14:50 +0000 Subject: [PATCH] feat: Support CMEK for BQ tables (#403) * feat: Support CMEK for BQ tables * add more tests * add unit tests * add more tests, fix broken tests * separate bqml client to send kms_key_name via OPTIONS instead of job config * fix unit tests * fix mypy * skip cmek test for empty cmek * move staticmethods to helper module * revert bqmlclient, pass cmek through call time job config * revert bqmlclient unit test * fix mypy failure * use better named key, disable use_query_cache in test * rename bqml create model internal method * fix renamed methods's reference in unit tests * remove stray bqmlclient variable --- bigframes/_config/bigquery_options.py | 25 ++ bigframes/ml/core.py | 10 +- bigframes/pandas/__init__.py | 1 + bigframes/session/__init__.py | 126 +++++++--- bigframes/session/_io/bigquery.py | 39 ++- bigframes/session/clients.py | 46 ++-- tests/system/small/test_encryption.py | 256 ++++++++++++++++++++ tests/unit/_config/test_bigquery_options.py | 2 + tests/unit/ml/test_golden_sql.py | 11 +- tests/unit/session/test_clients.py | 1 + 10 files changed, 450 insertions(+), 67 deletions(-) create mode 100644 tests/system/small/test_encryption.py diff --git a/bigframes/_config/bigquery_options.py b/bigframes/_config/bigquery_options.py index 74b83429d0..34701740f6 100644 --- a/bigframes/_config/bigquery_options.py +++ b/bigframes/_config/bigquery_options.py @@ -39,6 +39,7 @@ def __init__( bq_connection: Optional[str] = None, use_regional_endpoints: bool = False, application_name: Optional[str] = None, + kms_key_name: Optional[str] = None, ): self._credentials = credentials self._project = project @@ -46,6 +47,7 @@ def __init__( self._bq_connection = bq_connection self._use_regional_endpoints = use_regional_endpoints self._application_name = application_name + self._kms_key_name = kms_key_name self._session_started = False 
@property @@ -148,3 +150,26 @@ def use_regional_endpoints(self, value: bool): ) self._use_regional_endpoints = value + + @property + def kms_key_name(self) -> Optional[str]: + """Customer managed encryption key used to control encryption of the + data-at-rest in BigQuery. This is of the format + projects/PROJECT_ID/locations/LOCATION/keyRings/KEYRING/cryptoKeys/KEY + + See https://cloud.google.com/bigquery/docs/customer-managed-encryption + for more details. + + Please make sure the project used for BigQuery DataFrames has "Cloud KMS + CryptoKey Encrypter/Decrypter" role in the key's project. See + https://cloud.google.com/bigquery/docs/customer-managed-encryption#assign_role + for steps on how to ensure that. + """ + return self._kms_key_name + + @kms_key_name.setter + def kms_key_name(self, value: str): + if self._session_started and self._kms_key_name != value: + raise ValueError(SESSION_STARTED_MESSAGE.format(attribute="kms_key_name")) + + self._kms_key_name = value diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py index c496133aa7..24997708fb 100644 --- a/bigframes/ml/core.py +++ b/bigframes/ml/core.py @@ -212,7 +212,8 @@ def principal_component_info(self) -> bpd.DataFrame: return self._session.read_gbq(sql) def copy(self, new_model_name: str, replace: bool = False) -> BqmlModel: - job_config = bigquery.job.CopyJobConfig() + job_config = self._session._prepare_copy_job_config() + if replace: job_config.write_disposition = "WRITE_TRUNCATE" @@ -236,7 +237,7 @@ def register(self, vertex_ai_model_id: Optional[str] = None) -> BqmlModel: options={"vertex_ai_model_id": vertex_ai_model_id} ) # Register the model and wait it to finish - self._session._start_query(sql) + self._session._start_query_create_model(sql) self._model = self._session.bqclient.get_model(self.model_name) return self @@ -255,7 +256,7 @@ def _create_model_ref( def _create_model_with_sql(self, session: bigframes.Session, sql: str) -> BqmlModel: # fit the model, synchronously - _, job = 
session._start_query(sql) + _, job = session._start_query_create_model(sql) # real model path in the session specific hidden dataset and table prefix model_name_full = f"{job.destination.project}.{job.destination.dataset_id}.{job.destination.table_id}" @@ -298,6 +299,9 @@ def create_model( options.update({"INPUT_LABEL_COLS": y_train.columns.tolist()}) session = X_train._session + if session._bq_kms_key_name: + options.update({"kms_key_name": session._bq_kms_key_name}) + model_ref = self._create_model_ref(session._anonymous_dataset) sql = self._model_creation_sql_generator.create_model( diff --git a/bigframes/pandas/__init__.py b/bigframes/pandas/__init__.py index 3120e96b1a..195d7eabfa 100644 --- a/bigframes/pandas/__init__.py +++ b/bigframes/pandas/__init__.py @@ -383,6 +383,7 @@ def _set_default_session_location_if_possible(query): use_regional_endpoints=options.bigquery.use_regional_endpoints, credentials=options.bigquery.credentials, application_name=options.bigquery.application_name, + bq_kms_key_name=options.bigquery.kms_key_name, ) bqclient = clients_provider.bqclient diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 190ce17ee1..b553865ea9 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -65,7 +65,6 @@ import bigframes._config.bigquery_options as bigquery_options import bigframes.constants as constants -from bigframes.core import log_adapter import bigframes.core as core import bigframes.core.blocks as blocks import bigframes.core.compile @@ -84,7 +83,6 @@ # Even though the ibis.backends.bigquery import is unused, it's needed # to register new and replacement ops with the Ibis BigQuery backend. 
-import third_party.bigframes_vendored.ibis.backends.bigquery # noqa import third_party.bigframes_vendored.ibis.expr.operations as vendored_ibis_ops import third_party.bigframes_vendored.pandas.io.gbq as third_party_pandas_gbq import third_party.bigframes_vendored.pandas.io.parquet as third_party_pandas_parquet @@ -161,6 +159,8 @@ def __init__( else: self._location = context.location + self._bq_kms_key_name = context.kms_key_name + # Instantiate a clients provider to help with cloud clients that will be # used in the future operations in the session if clients_provider: @@ -172,9 +172,17 @@ def __init__( use_regional_endpoints=context.use_regional_endpoints, credentials=context.credentials, application_name=context.application_name, + bq_kms_key_name=self._bq_kms_key_name, ) self._create_bq_datasets() + + # TODO(shobs): Remove this logic after https://github.com/ibis-project/ibis/issues/8494 + # has been fixed. The ibis client changes the default query job config + # so we are going to remember the current config and restore it after + # the ibis client has been created + original_default_query_job_config = self.bqclient.default_query_job_config + self.ibis_client = typing.cast( ibis_bigquery.Backend, ibis.bigquery.connect( @@ -184,6 +192,9 @@ def __init__( ), ) + self.bqclient.default_query_job_config = original_default_query_job_config + + # Resolve the BQ connection for remote function and Vertex AI integration self._bq_connection = context.bq_connection or _BIGFRAMES_DEFAULT_CONNECTION_ID # Now that we're starting the session, don't allow the options to be @@ -929,6 +940,8 @@ def _read_pandas_load_job( pandas_dataframe_copy.columns = pandas.Index(new_col_ids) pandas_dataframe_copy[ordering_col] = np.arange(pandas_dataframe_copy.shape[0]) + job_config = self._prepare_load_job_config() + # Specify the datetime dtypes, which is auto-detected as timestamp types. 
schema: list[bigquery.SchemaField] = [] for column, dtype in zip(pandas_dataframe.columns, pandas_dataframe.dtypes): @@ -936,12 +949,12 @@ def _read_pandas_load_job( schema.append( bigquery.SchemaField(column, bigquery.enums.SqlTypeNames.DATETIME) ) + job_config.schema = schema # Clustering probably not needed anyways as pandas tables are small cluster_cols = [ordering_col] - - job_config = bigquery.LoadJobConfig(schema=schema) job_config.clustering_fields = cluster_cols + job_config.labels = {"bigframes-api": api_name} load_table_destination = bigframes_io.random_table(self._anonymous_dataset) @@ -1061,7 +1074,7 @@ def read_csv( f"{constants.FEEDBACK_LINK}" ) - job_config = bigquery.LoadJobConfig() + job_config = self._prepare_load_job_config() job_config.create_disposition = bigquery.CreateDisposition.CREATE_IF_NEEDED job_config.source_format = bigquery.SourceFormat.CSV job_config.write_disposition = bigquery.WriteDisposition.WRITE_EMPTY @@ -1136,7 +1149,7 @@ def read_parquet( table = bigframes_io.random_table(self._anonymous_dataset) if engine == "bigquery": - job_config = bigquery.LoadJobConfig() + job_config = self._prepare_load_job_config() job_config.create_disposition = bigquery.CreateDisposition.CREATE_IF_NEEDED job_config.source_format = bigquery.SourceFormat.PARQUET job_config.write_disposition = bigquery.WriteDisposition.WRITE_EMPTY @@ -1194,7 +1207,7 @@ def read_json( "'lines' keyword is only valid when 'orient' is 'records'." 
) - job_config = bigquery.LoadJobConfig() + job_config = self._prepare_load_job_config() job_config.create_disposition = bigquery.CreateDisposition.CREATE_IF_NEEDED job_config.source_format = bigquery.SourceFormat.NEWLINE_DELIMITED_JSON job_config.write_disposition = bigquery.WriteDisposition.WRITE_EMPTY @@ -1518,6 +1531,53 @@ def read_gbq_function( session=self, ) + def _prepare_query_job_config( + self, + job_config: Optional[bigquery.QueryJobConfig] = None, + ) -> bigquery.QueryJobConfig: + if job_config is None: + job_config = bigquery.QueryJobConfig() + else: + # Create a copy so that we don't mutate the original config passed + job_config = typing.cast( + bigquery.QueryJobConfig, + bigquery.QueryJobConfig.from_api_repr(job_config.to_api_repr()), + ) + + if bigframes.options.compute.maximum_bytes_billed is not None: + job_config.maximum_bytes_billed = ( + bigframes.options.compute.maximum_bytes_billed + ) + + if self._bq_kms_key_name: + job_config.destination_encryption_configuration = ( + bigquery.EncryptionConfiguration(kms_key_name=self._bq_kms_key_name) + ) + + return job_config + + def _prepare_load_job_config(self) -> bigquery.LoadJobConfig: + # Start from a fresh load job config and apply the session's CMEK, if any + job_config = bigquery.LoadJobConfig() + + if self._bq_kms_key_name: + job_config.destination_encryption_configuration = ( + bigquery.EncryptionConfiguration(kms_key_name=self._bq_kms_key_name) + ) + + return job_config + + def _prepare_copy_job_config(self) -> bigquery.CopyJobConfig: + # Start from a fresh copy job config and apply the session's CMEK, if any + job_config = bigquery.CopyJobConfig() + + if self._bq_kms_key_name: + job_config.destination_encryption_configuration = ( + bigquery.EncryptionConfiguration(kms_key_name=self._bq_kms_key_name) + ) + + return job_config + def _start_query( self, sql: str, @@ -1525,29 +1585,30 @@ def _start_query( max_results: Optional[int] = None, ) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]: """ - 
Starts query job and waits for results. + Starts BigQuery query job and waits for results. """ - job_config = self._prepare_job_config(job_config) - api_methods = log_adapter.get_and_reset_api_methods() - job_config.labels = bigframes_io.create_job_configs_labels( - job_configs_labels=job_config.labels, api_methods=api_methods + job_config = self._prepare_query_job_config(job_config) + return bigframes.session._io.bigquery.start_query_with_client( + self.bqclient, sql, job_config, max_results ) - try: - query_job = self.bqclient.query(sql, job_config=job_config) - except google.api_core.exceptions.Forbidden as ex: - if "Drive credentials" in ex.message: - ex.message += "\nCheck https://cloud.google.com/bigquery/docs/query-drive-data#Google_Drive_permissions." - raise + def _start_query_create_model( + self, + sql: str, + ) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]: + """ + Starts BigQuery ML CREATE MODEL query job and waits for results. + """ + job_config = self._prepare_query_job_config() - opts = bigframes.options.display - if opts.progress_bar is not None and not query_job.configuration.dry_run: - results_iterator = formatting_helpers.wait_for_query_job( - query_job, max_results, opts.progress_bar - ) - else: - results_iterator = query_job.result(max_results=max_results) - return results_iterator, query_job + # BQML expects kms_key_name through OPTIONS and not through job config, + # so we must reset any encryption set in the job config + # https://cloud.google.com/bigquery/docs/customer-managed-encryption#encrypt-model + job_config.destination_encryption_configuration = None + + return bigframes.session._io.bigquery.start_query_with_client( + self.bqclient, sql, job_config + ) def _cache_with_cluster_cols( self, array_value: core.ArrayValue, cluster_cols: typing.Sequence[str] @@ -1696,19 +1757,6 @@ def _start_generic_job(self, job: formatting_helpers.GenericJob): else: job.result() - def _prepare_job_config( - self, job_config: 
Optional[bigquery.QueryJobConfig] = None - ) -> bigquery.QueryJobConfig: - if job_config is None: - job_config = self.bqclient.default_query_job_config - if job_config is None: - job_config = bigquery.QueryJobConfig() - if bigframes.options.compute.maximum_bytes_billed is not None: - job_config.maximum_bytes_billed = ( - bigframes.options.compute.maximum_bytes_billed - ) - return job_config - def connect(context: Optional[bigquery_options.BigQueryOptions] = None) -> Session: return Session(context) diff --git a/bigframes/session/_io/bigquery.py b/bigframes/session/_io/bigquery.py index 3695fc98e8..67820bbbcb 100644 --- a/bigframes/session/_io/bigquery.py +++ b/bigframes/session/_io/bigquery.py @@ -20,11 +20,17 @@ import itertools import textwrap import types -from typing import Dict, Iterable, Optional, Sequence, Union +from typing import Dict, Iterable, Optional, Sequence, Tuple, Union import uuid +import google.api_core.exceptions import google.cloud.bigquery as bigquery +import bigframes +from bigframes.core import log_adapter +import bigframes.formatting_helpers as formatting_helpers +import bigframes.session._io.bigquery as bigframes_io + IO_ORDERING_ID = "bqdf_row_nums" MAX_LABELS_COUNT = 64 TEMP_TABLE_PREFIX = "bqdf{date}_{random_id}" @@ -207,3 +213,34 @@ def format_option(key: str, value: Union[bool, str]) -> str: if isinstance(value, bool): return f"{key}=true" if value else f"{key}=false" return f"{key}={repr(value)}" + + +def start_query_with_client( + bq_client: bigquery.Client, + sql: str, + job_config: bigquery.job.QueryJobConfig, + max_results: Optional[int] = None, +) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]: + """ + Starts query job and waits for results. 
+ """ + api_methods = log_adapter.get_and_reset_api_methods() + job_config.labels = bigframes_io.create_job_configs_labels( + job_configs_labels=job_config.labels, api_methods=api_methods + ) + + try: + query_job = bq_client.query(sql, job_config=job_config) + except google.api_core.exceptions.Forbidden as ex: + if "Drive credentials" in ex.message: + ex.message += "\nCheck https://cloud.google.com/bigquery/docs/query-drive-data#Google_Drive_permissions." + raise + + opts = bigframes.options.display + if opts.progress_bar is not None and not query_job.configuration.dry_run: + results_iterator = formatting_helpers.wait_for_query_job( + query_job, max_results, opts.progress_bar + ) + else: + results_iterator = query_job.result(max_results=max_results) + return results_iterator, query_job diff --git a/bigframes/session/clients.py b/bigframes/session/clients.py index 627c9258a6..7574aa4454 100644 --- a/bigframes/session/clients.py +++ b/bigframes/session/clients.py @@ -68,6 +68,7 @@ def __init__( use_regional_endpoints: Optional[bool], credentials: Optional[google.auth.credentials.Credentials], application_name: Optional[str], + bq_kms_key_name: Optional[str], ): credentials_project = None if credentials is None: @@ -98,6 +99,7 @@ def __init__( self._location = location self._use_regional_endpoints = use_regional_endpoints self._credentials = credentials + self._bq_kms_key_name = bq_kms_key_name # cloud clients initialized for lazy load self._bqclient = None @@ -106,28 +108,34 @@ def __init__( self._cloudfunctionsclient = None self._resourcemanagerclient = None + def _create_bigquery_client(self): + bq_options = None + if self._use_regional_endpoints: + bq_options = google.api_core.client_options.ClientOptions( + api_endpoint=( + _BIGQUERY_REGIONAL_ENDPOINT + if self._location.lower() in _REP_SUPPORTED_REGIONS + else _BIGQUERY_LOCATIONAL_ENDPOINT + ).format(location=self._location), + ) + bq_info = google.api_core.client_info.ClientInfo( + 
user_agent=self._application_name + ) + + bq_client = bigquery.Client( + client_info=bq_info, + client_options=bq_options, + credentials=self._credentials, + project=self._project, + location=self._location, + ) + + return bq_client + @property def bqclient(self): if not self._bqclient: - bq_options = None - if self._use_regional_endpoints: - bq_options = google.api_core.client_options.ClientOptions( - api_endpoint=( - _BIGQUERY_REGIONAL_ENDPOINT - if self._location.lower() in _REP_SUPPORTED_REGIONS - else _BIGQUERY_LOCATIONAL_ENDPOINT - ).format(location=self._location), - ) - bq_info = google.api_core.client_info.ClientInfo( - user_agent=self._application_name - ) - self._bqclient = bigquery.Client( - client_info=bq_info, - client_options=bq_options, - credentials=self._credentials, - project=self._project, - location=self._location, - ) + self._bqclient = self._create_bigquery_client() return self._bqclient diff --git a/tests/system/small/test_encryption.py b/tests/system/small/test_encryption.py new file mode 100644 index 0000000000..0ce9d881fd --- /dev/null +++ b/tests/system/small/test_encryption.py @@ -0,0 +1,256 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud import bigquery +import pandas +import pytest + +import bigframes +import bigframes.ml.linear_model + + +@pytest.fixture(scope="module") +def bq_cmek() -> str: + """Customer managed encryption key to encrypt BigQuery data at rest. 
+ + This is of the form projects/PROJECT_ID/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY + + See https://cloud.google.com/bigquery/docs/customer-managed-encryption for steps. + """ + + # NOTE: This key is manually set up through the cloud console + # TODO(shobs): Automate the key creation during the test. This will + # require extra IAM privileges for the test runner. + return "projects/bigframes-dev-perf/locations/us/keyRings/bigframesKeyRing/cryptoKeys/bigframesKey" + + +@pytest.fixture(scope="module") +def session_with_bq_cmek(bq_cmek) -> bigframes.Session: + session = bigframes.Session(bigframes.BigQueryOptions(kms_key_name=bq_cmek)) + + return session + + +def _assert_bq_table_is_encrypted( + df: bigframes.dataframe.DataFrame, + cmek: str, + session: bigframes.Session, +): + # Materialize the data in BQ + repr(df) + + # The df should be backed by a query job with intended encryption on the result table + assert df.query_job is not None + assert df.query_job.destination_encryption_configuration.kms_key_name.startswith( + cmek + ) + + # The result table should exist with the intended encryption + table = session.bqclient.get_table(df.query_job.destination) + assert table.encryption_configuration.kms_key_name == cmek + + +def test_session_query_job(bq_cmek, session_with_bq_cmek): + if not bq_cmek: + pytest.skip("no cmek set for testing") + + _, query_job = session_with_bq_cmek._start_query( + "SELECT 123", job_config=bigquery.QueryJobConfig(use_query_cache=False) + ) + query_job.result() + + assert query_job.destination_encryption_configuration.kms_key_name.startswith( + bq_cmek + ) + + # The result table should exist with the intended encryption + table = session_with_bq_cmek.bqclient.get_table(query_job.destination) + assert table.encryption_configuration.kms_key_name == bq_cmek + + +def test_session_load_job(bq_cmek, session_with_bq_cmek): + if not bq_cmek: + pytest.skip("no cmek set for testing") + + # Session should have cmek set in the default 
query and load job configs + load_table = bigframes.session._io.bigquery.random_table( + session_with_bq_cmek._anonymous_dataset + ) + + df = pandas.DataFrame({"col0": [1, 2, 3]}) + load_job_config = session_with_bq_cmek._prepare_load_job_config() + load_job_config.schema = [ + bigquery.SchemaField(df.columns[0], bigquery.enums.SqlTypeNames.INT64) + ] + + load_job = session_with_bq_cmek.bqclient.load_table_from_dataframe( + df, + load_table, + job_config=load_job_config, + ) + load_job.result() + + assert load_job.destination == load_table + assert load_job.destination_encryption_configuration.kms_key_name.startswith( + bq_cmek + ) + + # The load destination table should be created with the intended encryption + table = session_with_bq_cmek.bqclient.get_table(load_job.destination) + assert table.encryption_configuration.kms_key_name == bq_cmek + + +def test_read_gbq(bq_cmek, session_with_bq_cmek, scalars_table_id): + if not bq_cmek: + pytest.skip("no cmek set for testing") + + # Read the BQ table + df = session_with_bq_cmek.read_gbq(scalars_table_id) + + # Assert encryption + _assert_bq_table_is_encrypted(df, bq_cmek, session_with_bq_cmek) + + +def test_df_apis(bq_cmek, session_with_bq_cmek, scalars_table_id): + if not bq_cmek: + pytest.skip("no cmek set for testing") + + # Read a BQ table + df = session_with_bq_cmek.read_gbq(scalars_table_id) + + # Perform a few dataframe operations and assert encryption + df1 = df.dropna() + _assert_bq_table_is_encrypted(df1, bq_cmek, session_with_bq_cmek) + + df2 = df1.head() + _assert_bq_table_is_encrypted(df2, bq_cmek, session_with_bq_cmek) + + +@pytest.mark.parametrize( + "engine", + [ + pytest.param("bigquery", id="bq_engine"), + pytest.param( + None, + id="default_engine", + marks=pytest.mark.skip( + reason="Internal issue 327544164, cmek does not propagate to the dataframe." 
+ ), + ), + ], +) +def test_read_csv_gcs( + bq_cmek, session_with_bq_cmek, scalars_df_index, gcs_folder, engine +): + if not bq_cmek: + pytest.skip("no cmek set for testing") + + # Create a csv in gcs + write_path = gcs_folder + "test_read_csv_gcs_bigquery_engine*.csv" + read_path = ( + write_path.replace("*", "000000000000") if engine is None else write_path + ) + scalars_df_index.to_csv(write_path) + + # Read the csv from gcs + df = session_with_bq_cmek.read_csv(read_path, engine=engine) + + # Assert encryption + _assert_bq_table_is_encrypted(df, bq_cmek, session_with_bq_cmek) + + +def test_to_gbq(bq_cmek, session_with_bq_cmek, scalars_table_id): + if not bq_cmek: + pytest.skip("no cmek set for testing") + + # Read a BQ table and assert encryption + df = session_with_bq_cmek.read_gbq(scalars_table_id) + _assert_bq_table_is_encrypted(df, bq_cmek, session_with_bq_cmek) + + # Modify the dataframe and assert encryption + df = df.dropna().head() + _assert_bq_table_is_encrypted(df, bq_cmek, session_with_bq_cmek) + + # Write the result to BQ and assert encryption + output_table_id = df.to_gbq() + output_table = session_with_bq_cmek.bqclient.get_table(output_table_id) + assert output_table.encryption_configuration.kms_key_name == bq_cmek + + +@pytest.mark.skip( + reason="Internal issue 327544164, cmek does not propagate to the dataframe." 
+) +def test_read_pandas(bq_cmek, session_with_bq_cmek): + if not bq_cmek: + pytest.skip("no cmek set for testing") + + # Read a pandas dataframe + df = session_with_bq_cmek.read_pandas(pandas.DataFrame([1])) + + # Assert encryption + _assert_bq_table_is_encrypted(df, bq_cmek, session_with_bq_cmek) + + +def test_read_pandas_large(bq_cmek, session_with_bq_cmek): + if not bq_cmek: + pytest.skip("no cmek set for testing") + + # Read a pandas dataframe large enough to trigger a BQ load job + df = session_with_bq_cmek.read_pandas(pandas.DataFrame(range(10_000))) + + # Assert encryption + _assert_bq_table_is_encrypted(df, bq_cmek, session_with_bq_cmek) + + +def test_bqml(bq_cmek, session_with_bq_cmek, penguins_table_id): + if not bq_cmek: + pytest.skip("no cmek set for testing") + + model = bigframes.ml.linear_model.LinearRegression() + df = session_with_bq_cmek.read_gbq(penguins_table_id).dropna() + X_train = df[ + [ + "species", + "island", + "culmen_length_mm", + "culmen_depth_mm", + "flipper_length_mm", + "sex", + ] + ] + y_train = df[["body_mass_g"]] + model.fit(X_train, y_train) + + assert model is not None + assert model._bqml_model.model.encryption_configuration is not None + assert model._bqml_model.model.encryption_configuration.kms_key_name == bq_cmek + + # Assert that model exists in BQ with intended encryption + model_bq = session_with_bq_cmek.bqclient.get_model(model._bqml_model.model_name) + assert model_bq.encryption_configuration.kms_key_name == bq_cmek + + # Explicitly save the model to a destination and assert that encryption holds + model_ref = model._bqml_model_factory._create_model_ref( + session_with_bq_cmek._anonymous_dataset + ) + model_ref_full_name = ( + f"{model_ref.project}.{model_ref.dataset_id}.{model_ref.model_id}" + ) + new_model = model.to_gbq(model_ref_full_name) + assert new_model._bqml_model.model.encryption_configuration.kms_key_name == bq_cmek + + # Assert that model exists in BQ with intended encryption + model_bq = 
session_with_bq_cmek.bqclient.get_model(new_model._bqml_model.model_name) + assert model_bq.encryption_configuration.kms_key_name == bq_cmek diff --git a/tests/unit/_config/test_bigquery_options.py b/tests/unit/_config/test_bigquery_options.py index e5b6cfe2f1..1ce70e3da2 100644 --- a/tests/unit/_config/test_bigquery_options.py +++ b/tests/unit/_config/test_bigquery_options.py @@ -29,6 +29,7 @@ ("project", "my-project", "my-other-project"), ("bq_connection", "path/to/connection/1", "path/to/connection/2"), ("use_regional_endpoints", False, True), + ("kms_key_name", "kms/key/name/1", "kms/key/name/2"), ], ) def test_setter_raises_if_session_started(attribute, original_value, new_value): @@ -61,6 +62,7 @@ def test_setter_raises_if_session_started(attribute, original_value, new_value): "project", "bq_connection", "use_regional_endpoints", + "bq_kms_key_name", ] ], ) diff --git a/tests/unit/ml/test_golden_sql.py b/tests/unit/ml/test_golden_sql.py index 017c96d46d..25e12d87bf 100644 --- a/tests/unit/ml/test_golden_sql.py +++ b/tests/unit/ml/test_golden_sql.py @@ -35,6 +35,7 @@ def mock_session(): mock_session._anonymous_dataset = bigquery.DatasetReference( TEMP_MODEL_ID.project, TEMP_MODEL_ID.dataset_id ) + mock_session._bq_kms_key_name = None query_job = mock.create_autospec(bigquery.QueryJob) type(query_job).destination = mock.PropertyMock( @@ -42,7 +43,7 @@ def mock_session(): mock_session._anonymous_dataset, TEMP_MODEL_ID.model_id ) ) - mock_session._start_query.return_value = (None, query_job) + mock_session._start_query_create_model.return_value = (None, query_job) return mock_session @@ -103,7 +104,7 @@ def test_linear_regression_default_fit( model._bqml_model_factory = bqml_model_factory model.fit(mock_X, mock_y) - mock_session._start_query.assert_called_once_with( + mock_session._start_query_create_model.assert_called_once_with( 'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n 
data_split_method="NO_SPLIT",\n optimize_strategy="normal_equation",\n fit_intercept=True,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n ls_init_learn_rate=0.1,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql' ) @@ -113,7 +114,7 @@ def test_linear_regression_params_fit(bqml_model_factory, mock_session, mock_X, model._bqml_model_factory = bqml_model_factory model.fit(mock_X, mock_y) - mock_session._start_query.assert_called_once_with( + mock_session._start_query_create_model.assert_called_once_with( 'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="normal_equation",\n fit_intercept=False,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n ls_init_learn_rate=0.1,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql' ) @@ -146,7 +147,7 @@ def test_logistic_regression_default_fit( model._bqml_model_factory = bqml_model_factory model.fit(mock_X, mock_y) - mock_session._start_query.assert_called_once_with( + mock_session._start_query_create_model.assert_called_once_with( 'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LOGISTIC_REG",\n data_split_method="NO_SPLIT",\n fit_intercept=True,\n auto_class_weights=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql' ) @@ -160,7 +161,7 @@ def test_logistic_regression_params_fit( model._bqml_model_factory = bqml_model_factory model.fit(mock_X, mock_y) - mock_session._start_query.assert_called_once_with( + mock_session._start_query_create_model.assert_called_once_with( 'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LOGISTIC_REG",\n 
data_split_method="NO_SPLIT",\n fit_intercept=False,\n auto_class_weights=True,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql' ) diff --git a/tests/unit/session/test_clients.py b/tests/unit/session/test_clients.py index f1b2a5045a..30ba2f9091 100644 --- a/tests/unit/session/test_clients.py +++ b/tests/unit/session/test_clients.py @@ -38,6 +38,7 @@ def create_clients_provider(application_name: Optional[str] = None): use_regional_endpoints=False, credentials=credentials, application_name=application_name, + bq_kms_key_name="projects/my-project/locations/us/keyRings/myKeyRing/cryptoKeys/myKey", )