From 144190e790cb663dec5d60392e8c96240cd26087 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Thu, 29 Feb 2024 08:10:44 +0000 Subject: [PATCH 01/16] feat: Support CMEK for BQ tables --- bigframes/_config/bigquery_options.py | 25 ++++++ bigframes/session/__init__.py | 13 +++ bigframes/session/clients.py | 17 ++++ tests/system/small/test_encryption.py | 120 ++++++++++++++++++++++++++ 4 files changed, 175 insertions(+) create mode 100644 tests/system/small/test_encryption.py diff --git a/bigframes/_config/bigquery_options.py b/bigframes/_config/bigquery_options.py index 2875a11de3..ee3ff56ba4 100644 --- a/bigframes/_config/bigquery_options.py +++ b/bigframes/_config/bigquery_options.py @@ -39,6 +39,7 @@ def __init__( bq_connection: Optional[str] = None, use_regional_endpoints: bool = False, application_name: Optional[str] = None, + kms_key_name: Optional[str] = None, ): self._credentials = credentials self._project = project @@ -46,6 +47,7 @@ def __init__( self._bq_connection = bq_connection self._use_regional_endpoints = use_regional_endpoints self._application_name = application_name + self._kms_key_name = kms_key_name self._session_started = False @property @@ -150,3 +152,26 @@ def use_regional_endpoints(self, value: bool): ) self._use_regional_endpoints = value + + @property + def kms_key_name(self) -> Optional[str]: + """Customer managed encryption key used to control encryption of the + data-at-rest in BigQuery. This is of the format + projects/PROJECT_ID/locations/LOCATION/keyRings/KEYRING/cryptoKeys/KEY + + See https://cloud.google.com/bigquery/docs/customer-managed-encryption + for more details. + + Please make sure the project used for Bigquery DataFrames has "Cloud KMS + CryptoKey Encrypter/Decrypter" role in the key's project, See + https://cloud.google.com/bigquery/docs/customer-managed-encryption#assign_role + for steps on how to ensure that. 
+ """ + return self._kms_key_name + + @kms_key_name.setter + def kms_key_name(self, value: str): + if self._session_started and self._kms_key_name != value: + raise ValueError(SESSION_STARTED_MESSAGE.format(attribute="kms_key_name")) + + self._kms_key_name = value diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 20dd39c0fa..0f427489f8 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -161,6 +161,8 @@ def __init__( else: self._location = context.location + self._bq_kms_key_name = context.kms_key_name + # Instantiate a clients provider to help with cloud clients that will be # used in the future operations in the session if clients_provider: @@ -172,9 +174,17 @@ def __init__( use_regional_endpoints=context.use_regional_endpoints, credentials=context.credentials, application_name=context.application_name, + bq_kms_key_name=self._bq_kms_key_name, ) self._create_bq_datasets() + + # TODO(shobs): Remove this logic after https://github.com/ibis-project/ibis/issues/8494 + # has been fixed. 
The ibis client changes the default query job config + # so we are going to remember the current config and restore it after + # the ibis client has been created + original_default_query_job_config = self.bqclient.default_query_job_config + self.ibis_client = typing.cast( ibis_bigquery.Backend, ibis.bigquery.connect( @@ -184,6 +194,9 @@ def __init__( ), ) + self.bqclient.default_query_job_config = original_default_query_job_config + + # Resolve the BQ connection for remote function and Vertex AI integration self._bq_connection = context.bq_connection or _BIGFRAMES_DEFAULT_CONNECTION_ID # Now that we're starting the session, don't allow the options to be diff --git a/bigframes/session/clients.py b/bigframes/session/clients.py index e33413002f..432ddddb7e 100644 --- a/bigframes/session/clients.py +++ b/bigframes/session/clients.py @@ -60,6 +60,7 @@ def __init__( use_regional_endpoints: Optional[bool], credentials: Optional[google.auth.credentials.Credentials], application_name: Optional[str], + bq_kms_key_name: Optional[str], ): credentials_project = None if credentials is None: @@ -90,6 +91,7 @@ def __init__( self._location = location self._use_regional_endpoints = use_regional_endpoints self._credentials = credentials + self._bq_kms_key_name = bq_kms_key_name # cloud clients initialized for lazy load self._bqclient = None @@ -111,12 +113,27 @@ def bqclient(self): bq_info = google.api_core.client_info.ClientInfo( user_agent=self._application_name ) + default_query_job_config = None + default_load_job_config = None + if self._bq_kms_key_name: + default_query_job_config = bigquery.QueryJobConfig( + destination_encryption_configuration=bigquery.EncryptionConfiguration( + kms_key_name=self._bq_kms_key_name + ) + ) + # default_load_job_config = bigquery.LoadJobConfig( + # destination_encryption_configuration = bigquery.EncryptionConfiguration( + # kms_key_name=self._bq_kms_key_name + # ) + # ) self._bqclient = bigquery.Client( client_info=bq_info, 
client_options=bq_options, credentials=self._credentials, project=self._project, location=self._location, + default_query_job_config=default_query_job_config, + default_load_job_config=default_load_job_config, ) return self._bqclient diff --git a/tests/system/small/test_encryption.py b/tests/system/small/test_encryption.py new file mode 100644 index 0000000000..9b9a5235da --- /dev/null +++ b/tests/system/small/test_encryption.py @@ -0,0 +1,120 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas +import pytest + +import bigframes + + +@pytest.fixture(scope="module") +def bq_cmek() -> str: + """Customer managed encryption key to encrypt BigQuery data at rest. + + This is of the form projects/PROJECT_ID/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY + + See https://cloud.google.com/bigquery/docs/customer-managed-encryption for steps. + """ + + # We are keeping this empty, and the dependent tests are skipped if it is empty. + # In local testing we can set a key here and run the tests. + # TODO(shobs): Automate the tests depending on this fixture by either creating + # a static key in the test project or automating the key creation during the test. 
+ return "" + + +@pytest.fixture(scope="module") +def session_with_bq_cmek(bq_cmek) -> bigframes.Session: + session = bigframes.Session(bigframes.BigQueryOptions(kms_key_name=bq_cmek)) + + return session + + +def _assert_bq_table_is_encrypted( + df: bigframes.dataframe.DataFrame, + cmek: str, + session: bigframes.Session, +): + # Materialize the data in BQ + repr(df) + + # The df should be backed by a query job with intended encryption on the result table + assert df.query_job is not None + assert df.query_job.destination_encryption_configuration.kms_key_name.startswith( + cmek + ) + + # The result table should exist with the intended encryption + table = session.bqclient.get_table(df.query_job.destination) + assert table.encryption_configuration.kms_key_name == cmek + + +def test_read_gbq(bq_cmek, session_with_bq_cmek, scalars_table_id): + if not bq_cmek: + pytest.skip("no cmek set for testing") + + # Session should have cmek set in the default query job config + assert ( + session_with_bq_cmek.bqclient.default_query_job_config.destination_encryption_configuration.kms_key_name + == bq_cmek + ) + + # Read the BQ table + df = session_with_bq_cmek.read_gbq(scalars_table_id) + + # Assert encryption + _assert_bq_table_is_encrypted(df, bq_cmek, session_with_bq_cmek) + + +def test_read_csv_gcs_bigquery_engine( + bq_cmek, session_with_bq_cmek, scalars_df_index, gcs_folder +): + if not bq_cmek: + pytest.skip("no cmek set for testing") + + # Session should have cmek set in the default query job config + assert ( + session_with_bq_cmek.bqclient.default_query_job_config.destination_encryption_configuration.kms_key_name + == bq_cmek + ) + + # Create a csv in gcs + path = gcs_folder + "test_read_csv_gcs_bigquery_engine*.csv" + scalars_df_index.to_csv(path) + + # Read the BQ table + df = session_with_bq_cmek.read_csv(path, engine="bigquery") + + # Assert encryption + _assert_bq_table_is_encrypted(df, bq_cmek, session_with_bq_cmek) + + +@pytest.mark.skip( + reason="Internal issue 
327544164, cmek does not propagate to the dataframe." +) +def test_read_pandas(bq_cmek, session_with_bq_cmek): + if not bq_cmek: + pytest.skip("no cmek set for testing") + + # Session should have cmek set in the default query job config + assert ( + session_with_bq_cmek.bqclient.default_query_job_config.destination_encryption_configuration.kms_key_name + == bq_cmek + ) + + # Read the BQ table + df = session_with_bq_cmek.read_pandas(pandas.DataFrame([[1]])) + + # Assert encryption + _assert_bq_table_is_encrypted(df, bq_cmek, session_with_bq_cmek) From 6b41ec0dfd486e00348d877ce57f97936e072df4 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Fri, 1 Mar 2024 01:43:53 +0000 Subject: [PATCH 02/16] add more tests --- bigframes/session/clients.py | 10 +-- tests/system/small/test_encryption.py | 102 ++++++++++++++++++++------ 2 files changed, 85 insertions(+), 27 deletions(-) diff --git a/bigframes/session/clients.py b/bigframes/session/clients.py index 3f402ade37..a1eff612e3 100644 --- a/bigframes/session/clients.py +++ b/bigframes/session/clients.py @@ -131,11 +131,11 @@ def bqclient(self): kms_key_name=self._bq_kms_key_name ) ) - # default_load_job_config = bigquery.LoadJobConfig( - # destination_encryption_configuration = bigquery.EncryptionConfiguration( - # kms_key_name=self._bq_kms_key_name - # ) - # ) + default_load_job_config = bigquery.LoadJobConfig( + destination_encryption_configuration=bigquery.EncryptionConfiguration( + kms_key_name=self._bq_kms_key_name + ) + ) self._bqclient = bigquery.Client( client_info=bq_info, client_options=bq_options, diff --git a/tests/system/small/test_encryption.py b/tests/system/small/test_encryption.py index 9b9a5235da..9f2dd1755d 100644 --- a/tests/system/small/test_encryption.py +++ b/tests/system/small/test_encryption.py @@ -27,8 +27,9 @@ def bq_cmek() -> str: See https://cloud.google.com/bigquery/docs/customer-managed-encryption for steps. """ - # We are keeping this empty, and the dependent tests are skipped if it is empty. 
- # In local testing we can set a key here and run the tests. + # NOTE: Set a key here for local testing. + # We are keeping this empty by default, in which case the dependent tests + # are skipped. # TODO(shobs): Automate the tests depending on this fixture by either creating # a static key in the test project or automating the key creation during the test. return "" @@ -60,15 +61,24 @@ def _assert_bq_table_is_encrypted( assert table.encryption_configuration.kms_key_name == cmek -def test_read_gbq(bq_cmek, session_with_bq_cmek, scalars_table_id): +def test_session_default_configs(bq_cmek, session_with_bq_cmek): if not bq_cmek: pytest.skip("no cmek set for testing") - # Session should have cmek set in the default query job config + # Session should have cmek set in the default query and load job configs assert ( session_with_bq_cmek.bqclient.default_query_job_config.destination_encryption_configuration.kms_key_name == bq_cmek ) + assert ( + session_with_bq_cmek.bqclient.default_load_job_config.destination_encryption_configuration.kms_key_name + == bq_cmek + ) + + +def test_read_gbq(bq_cmek, session_with_bq_cmek, scalars_table_id): + if not bq_cmek: + pytest.skip("no cmek set for testing") # Read the BQ table df = session_with_bq_cmek.read_gbq(scalars_table_id) @@ -77,29 +87,72 @@ def test_read_gbq(bq_cmek, session_with_bq_cmek, scalars_table_id): _assert_bq_table_is_encrypted(df, bq_cmek, session_with_bq_cmek) -def test_read_csv_gcs_bigquery_engine( - bq_cmek, session_with_bq_cmek, scalars_df_index, gcs_folder -): +def test_df_apis(bq_cmek, session_with_bq_cmek, scalars_table_id): if not bq_cmek: pytest.skip("no cmek set for testing") - # Session should have cmek set in the default query job config - assert ( - session_with_bq_cmek.bqclient.default_query_job_config.destination_encryption_configuration.kms_key_name - == bq_cmek - ) + # Read a BQ table and assert encryption + df = session_with_bq_cmek.read_gbq(scalars_table_id) + + # Perform a few dataframe 
operations and assert encryption + df1 = df.dropna() + _assert_bq_table_is_encrypted(df1, bq_cmek, session_with_bq_cmek) + + df2 = df1.head() + _assert_bq_table_is_encrypted(df2, bq_cmek, session_with_bq_cmek) + + +@pytest.mark.parametrize( + "engine", + [ + pytest.param("bigquery", id="bq_engine"), + pytest.param( + None, + id="default_engine", + marks=pytest.mark.skip( + reason="Internal issue 327544164, cmek does not propagate to the dataframe." + ), + ), + ], +) +def test_read_csv_gcs( + bq_cmek, session_with_bq_cmek, scalars_df_index, gcs_folder, engine +): + if not bq_cmek: + pytest.skip("no cmek set for testing") # Create a csv in gcs - path = gcs_folder + "test_read_csv_gcs_bigquery_engine*.csv" - scalars_df_index.to_csv(path) + write_path = gcs_folder + "test_read_csv_gcs_bigquery_engine*.csv" + read_path = ( + write_path.replace("*", "000000000000") if engine is None else write_path + ) + scalars_df_index.to_csv(write_path) # Read the BQ table - df = session_with_bq_cmek.read_csv(path, engine="bigquery") + df = session_with_bq_cmek.read_csv(read_path, engine=engine) # Assert encryption _assert_bq_table_is_encrypted(df, bq_cmek, session_with_bq_cmek) +def test_to_gbq(bq_cmek, session_with_bq_cmek, scalars_table_id): + if not bq_cmek: + pytest.skip("no cmek set for testing") + + # Read a BQ table and assert encryption + df = session_with_bq_cmek.read_gbq(scalars_table_id) + _assert_bq_table_is_encrypted(df, bq_cmek, session_with_bq_cmek) + + # Modify the dataframe and assert encryption + df = df.dropna().head() + _assert_bq_table_is_encrypted(df, bq_cmek, session_with_bq_cmek) + + # Write the result to BQ and assert encryption + output_table_id = df.to_gbq() + output_table = session_with_bq_cmek.bqclient.get_table(output_table_id) + assert output_table.encryption_configuration.kms_key_name == bq_cmek + + @pytest.mark.skip( reason="Internal issue 327544164, cmek does not propagate to the dataframe." 
) @@ -107,14 +160,19 @@ def test_read_pandas(bq_cmek, session_with_bq_cmek): if not bq_cmek: pytest.skip("no cmek set for testing") - # Session should have cmek set in the default query job config - assert ( - session_with_bq_cmek.bqclient.default_query_job_config.destination_encryption_configuration.kms_key_name - == bq_cmek - ) + # Read a pandas dataframe + df = session_with_bq_cmek.read_pandas(pandas.DataFrame([1])) - # Read the BQ table - df = session_with_bq_cmek.read_pandas(pandas.DataFrame([[1]])) + # Assert encryption + _assert_bq_table_is_encrypted(df, bq_cmek, session_with_bq_cmek) + + +def test_read_pandas_large(bq_cmek, session_with_bq_cmek): + if not bq_cmek: + pytest.skip("no cmek set for testing") + + # Read a pandas dataframe large enough to trigger a BQ load job + df = session_with_bq_cmek.read_pandas(pandas.DataFrame(range(10_000))) # Assert encryption _assert_bq_table_is_encrypted(df, bq_cmek, session_with_bq_cmek) From d2f3f3d82007f095e5744ac38a324891adfbc79c Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Fri, 1 Mar 2024 01:55:58 +0000 Subject: [PATCH 03/16] add unit tests --- tests/unit/_config/test_bigquery_options.py | 2 ++ tests/unit/session/test_clients.py | 1 + 2 files changed, 3 insertions(+) diff --git a/tests/unit/_config/test_bigquery_options.py b/tests/unit/_config/test_bigquery_options.py index e5b6cfe2f1..1ce70e3da2 100644 --- a/tests/unit/_config/test_bigquery_options.py +++ b/tests/unit/_config/test_bigquery_options.py @@ -29,6 +29,7 @@ ("project", "my-project", "my-other-project"), ("bq_connection", "path/to/connection/1", "path/to/connection/2"), ("use_regional_endpoints", False, True), + ("kms_key_name", "kms/key/name/1", "kms/key/name/2"), ], ) def test_setter_raises_if_session_started(attribute, original_value, new_value): @@ -61,6 +62,7 @@ def test_setter_raises_if_session_started(attribute, original_value, new_value): "project", "bq_connection", "use_regional_endpoints", + "bq_kms_key_name", ] ], ) diff --git 
a/tests/unit/session/test_clients.py b/tests/unit/session/test_clients.py index f1b2a5045a..30ba2f9091 100644 --- a/tests/unit/session/test_clients.py +++ b/tests/unit/session/test_clients.py @@ -38,6 +38,7 @@ def create_clients_provider(application_name: Optional[str] = None): use_regional_endpoints=False, credentials=credentials, application_name=application_name, + bq_kms_key_name="projects/my-project/locations/us/keyRings/myKeyRing/cryptoKeys/myKey", ) From 425cf1217ae242d068776e0c85987f23a1f62ce2 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Fri, 1 Mar 2024 02:35:37 +0000 Subject: [PATCH 04/16] add more tests, fix broken tests --- bigframes/pandas/__init__.py | 1 + tests/system/small/test_encryption.py | 45 +++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/bigframes/pandas/__init__.py b/bigframes/pandas/__init__.py index 110978a7f1..f2d7fb2cb1 100644 --- a/bigframes/pandas/__init__.py +++ b/bigframes/pandas/__init__.py @@ -383,6 +383,7 @@ def _set_default_session_location_if_possible(query): use_regional_endpoints=options.bigquery.use_regional_endpoints, credentials=options.bigquery.credentials, application_name=options.bigquery.application_name, + bq_kms_key_name=options.bigquery.kms_key_name, ) bqclient = clients_provider.bqclient diff --git a/tests/system/small/test_encryption.py b/tests/system/small/test_encryption.py index 9f2dd1755d..a09c7e1848 100644 --- a/tests/system/small/test_encryption.py +++ b/tests/system/small/test_encryption.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+ +from google.cloud import bigquery import pandas import pytest @@ -76,6 +78,49 @@ def test_session_default_configs(bq_cmek, session_with_bq_cmek): ) +def test_session_query_job(bq_cmek, session_with_bq_cmek): + if not bq_cmek: + pytest.skip("no cmek set for testing") + + query_job = session_with_bq_cmek.bqclient.query("SELECT 123") + query_job.result() + + assert query_job.destination_encryption_configuration.kms_key_name.startswith( + bq_cmek + ) + + # The result table should exist with the intended encryption + table = session_with_bq_cmek.bqclient.get_table(query_job.destination) + assert table.encryption_configuration.kms_key_name == bq_cmek + + +def test_session_load_job(bq_cmek, session_with_bq_cmek): + if not bq_cmek: + pytest.skip("no cmek set for testing") + + # Load a pandas dataframe into a new table in the session's anonymous dataset + load_table = bigframes.session._io.bigquery.random_table( + session_with_bq_cmek._anonymous_dataset + ) + load_job = session_with_bq_cmek.bqclient.load_table_from_dataframe( + pandas.DataFrame({"col0": [1, 2, 3]}), + load_table, + job_config=bigquery.LoadJobConfig( + schema=[bigquery.SchemaField("col0", bigquery.enums.SqlTypeNames.INT64)] + ), + ) + load_job.result() + + assert load_job.destination == load_table + assert load_job.destination_encryption_configuration.kms_key_name.startswith( + bq_cmek + ) + + # The result table should exist with the intended encryption + table = session_with_bq_cmek.bqclient.get_table(load_job.destination) + assert table.encryption_configuration.kms_key_name == bq_cmek + + def test_read_gbq(bq_cmek, session_with_bq_cmek, scalars_table_id): if not bq_cmek: pytest.skip("no cmek set for testing") From 2443d4029a467fb98a5d3cdb5e2a567b16fcf69a Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Mon, 4 Mar 2024 18:28:15 +0000 Subject: [PATCH 05/16] separate bqml client to send kms_key_name via OPTIONS instead of job config --- bigframes/ml/core.py | 13 ++++-- bigframes/session/__init__.py | 43 
+++++++++++++++---- bigframes/session/clients.py | 60 ++++++++++++++++----------- tests/system/small/test_encryption.py | 43 ++++++++++++++++++- tests/unit/session/test_clients.py | 1 + 5 files changed, 124 insertions(+), 36 deletions(-) diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py index c496133aa7..9793208e41 100644 --- a/bigframes/ml/core.py +++ b/bigframes/ml/core.py @@ -212,7 +212,11 @@ def principal_component_info(self) -> bpd.DataFrame: return self._session.read_gbq(sql) def copy(self, new_model_name: str, replace: bool = False) -> BqmlModel: - job_config = bigquery.job.CopyJobConfig() + job_config = bigquery.job.CopyJobConfig( + destination_encryption_configuration=bigquery.EncryptionConfiguration( + kms_key_name=self._session._bq_kms_key_name + ) + ) if replace: job_config.write_disposition = "WRITE_TRUNCATE" @@ -236,7 +240,7 @@ def register(self, vertex_ai_model_id: Optional[str] = None) -> BqmlModel: options={"vertex_ai_model_id": vertex_ai_model_id} ) # Register the model and wait it to finish - self._session._start_query(sql) + self._session._start_query_bqml(sql) self._model = self._session.bqclient.get_model(self.model_name) return self @@ -255,7 +259,7 @@ def _create_model_ref( def _create_model_with_sql(self, session: bigframes.Session, sql: str) -> BqmlModel: # fit the model, synchronously - _, job = session._start_query(sql) + _, job = session._start_query_bqml(sql) # real model path in the session specific hidden dataset and table prefix model_name_full = f"{job.destination.project}.{job.destination.dataset_id}.{job.destination.table_id}" @@ -298,6 +302,9 @@ def create_model( options.update({"INPUT_LABEL_COLS": y_train.columns.tolist()}) session = X_train._session + if session._bq_kms_key_name: + options.update({"kms_key_name": session._bq_kms_key_name}) + model_ref = self._create_model_ref(session._anonymous_dataset) sql = self._model_creation_sql_generator.create_model( diff --git a/bigframes/session/__init__.py 
b/bigframes/session/__init__.py index 66b087716a..a6090b9c65 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -208,6 +208,10 @@ def __init__( def bqclient(self): return self._clients_provider.bqclient + @property + def bqmlclient(self): + return self._clients_provider.bqmlclient + @property def bqconnectionclient(self): return self._clients_provider.bqconnectionclient @@ -1509,8 +1513,8 @@ def read_gbq_function( session=self, ) - def _start_query( - self, + def _start_query_with_client( + bq_client: bigquery.Client, sql: str, job_config: Optional[bigquery.job.QueryJobConfig] = None, max_results: Optional[int] = None, @@ -1518,14 +1522,14 @@ def _start_query( """ Starts query job and waits for results. """ - job_config = self._prepare_job_config(job_config) + job_config = Session._prepare_job_config(job_config) api_methods = log_adapter.get_and_reset_api_methods() job_config.labels = bigframes_io.create_job_configs_labels( job_configs_labels=job_config.labels, api_methods=api_methods ) try: - query_job = self.bqclient.query(sql, job_config=job_config) + query_job = bq_client.query(sql, job_config=job_config) except google.api_core.exceptions.Forbidden as ex: if "Drive credentials" in ex.message: ex.message += "\nCheck https://cloud.google.com/bigquery/docs/query-drive-data#Google_Drive_permissions." @@ -1540,6 +1544,32 @@ def _start_query( results_iterator = query_job.result(max_results=max_results) return results_iterator, query_job + def _start_query( + self, + sql: str, + job_config: Optional[bigquery.job.QueryJobConfig] = None, + max_results: Optional[int] = None, + ) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]: + """ + Starts BigQuery query job and waits for results. 
+ """ + return Session._start_query_with_client( + self.bqclient, sql, job_config, max_results + ) + + def _start_query_bqml( + self, + sql: str, + job_config: Optional[bigquery.job.QueryJobConfig] = None, + max_results: Optional[int] = None, + ) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]: + """ + Starts BigQuery ML query job and waits for results. + """ + return Session._start_query_with_client( + self.bqmlclient, sql, job_config, max_results + ) + def _cache_with_cluster_cols( self, array_value: core.ArrayValue, cluster_cols: typing.Sequence[str] ) -> core.ArrayValue: @@ -1681,11 +1711,10 @@ def _start_generic_job(self, job: formatting_helpers.GenericJob): else: job.result() + @staticmethod def _prepare_job_config( - self, job_config: Optional[bigquery.QueryJobConfig] = None + job_config: Optional[bigquery.QueryJobConfig] = None, ) -> bigquery.QueryJobConfig: - if job_config is None: - job_config = self.bqclient.default_query_job_config if job_config is None: job_config = bigquery.QueryJobConfig() if bigframes.options.compute.maximum_bytes_billed is not None: diff --git a/bigframes/session/clients.py b/bigframes/session/clients.py index a1eff612e3..7e7456b07f 100644 --- a/bigframes/session/clients.py +++ b/bigframes/session/clients.py @@ -103,51 +103,61 @@ def __init__( # cloud clients initialized for lazy load self._bqclient = None + self._bqmlclient = None self._bqconnectionclient = None self._bqstoragereadclient = None self._cloudfunctionsclient = None self._resourcemanagerclient = None + def _create_bigquery_client(self): + bq_options = None + if self._use_regional_endpoints: + bq_options = google.api_core.client_options.ClientOptions( + api_endpoint=( + _BIGQUERY_REGIONAL_ENDPOINT + if self._location.lower() in _REP_SUPPORTED_REGIONS + else _BIGQUERY_LOCATIONAL_ENDPOINT + ).format(location=self._location), + ) + bq_info = google.api_core.client_info.ClientInfo( + user_agent=self._application_name + ) + + bq_client = bigquery.Client( + 
client_info=bq_info, + client_options=bq_options, + credentials=self._credentials, + project=self._project, + location=self._location, + ) + + return bq_client + @property def bqclient(self): if not self._bqclient: - bq_options = None - if self._use_regional_endpoints: - bq_options = google.api_core.client_options.ClientOptions( - api_endpoint=( - _BIGQUERY_REGIONAL_ENDPOINT - if self._location.lower() in _REP_SUPPORTED_REGIONS - else _BIGQUERY_LOCATIONAL_ENDPOINT - ).format(location=self._location), - ) - bq_info = google.api_core.client_info.ClientInfo( - user_agent=self._application_name - ) - default_query_job_config = None - default_load_job_config = None + self._bqclient = self._create_bigquery_client() if self._bq_kms_key_name: - default_query_job_config = bigquery.QueryJobConfig( + self._bqclient.default_query_job_config = bigquery.QueryJobConfig( destination_encryption_configuration=bigquery.EncryptionConfiguration( kms_key_name=self._bq_kms_key_name ) ) - default_load_job_config = bigquery.LoadJobConfig( + self._bqclient.default_load_job_config = bigquery.LoadJobConfig( destination_encryption_configuration=bigquery.EncryptionConfiguration( kms_key_name=self._bq_kms_key_name ) ) - self._bqclient = bigquery.Client( - client_info=bq_info, - client_options=bq_options, - credentials=self._credentials, - project=self._project, - location=self._location, - default_query_job_config=default_query_job_config, - default_load_job_config=default_load_job_config, - ) return self._bqclient + @property + def bqmlclient(self): + if not self._bqmlclient: + self._bqmlclient = self._create_bigquery_client() + + return self._bqmlclient + @property def bqconnectionclient(self): if not self._bqconnectionclient: diff --git a/tests/system/small/test_encryption.py b/tests/system/small/test_encryption.py index a09c7e1848..f0fd42e56e 100644 --- a/tests/system/small/test_encryption.py +++ b/tests/system/small/test_encryption.py @@ -18,6 +18,7 @@ import pytest import bigframes +import 
bigframes.ml.linear_model @pytest.fixture(scope="module") @@ -116,7 +117,7 @@ def test_session_load_job(bq_cmek, session_with_bq_cmek): bq_cmek ) - # The result table should exist with the intended encryption + # The load destination table should be created with the intended encryption table = session_with_bq_cmek.bqclient.get_table(load_job.destination) assert table.encryption_configuration.kms_key_name == bq_cmek @@ -221,3 +222,43 @@ def test_read_pandas_large(bq_cmek, session_with_bq_cmek): # Assert encryption _assert_bq_table_is_encrypted(df, bq_cmek, session_with_bq_cmek) + + +def test_bqml(bq_cmek, session_with_bq_cmek, penguins_table_id): + model = bigframes.ml.linear_model.LinearRegression() + + df = session_with_bq_cmek.read_gbq(penguins_table_id).dropna() + X_train = df[ + [ + "species", + "island", + "culmen_length_mm", + "culmen_depth_mm", + "flipper_length_mm", + "sex", + ] + ] + y_train = df[["body_mass_g"]] + model.fit(X_train, y_train) + + assert model is not None + assert model._bqml_model.model.encryption_configuration is not None + assert model._bqml_model.model.encryption_configuration.kms_key_name == bq_cmek + + # Assert that model exists in BQ with intended encryption + model_bq = session_with_bq_cmek.bqclient.get_model(model._bqml_model.model_name) + assert model_bq.encryption_configuration.kms_key_name == bq_cmek + + # Explicitly save the model to a destination and assert that encryption holds + model_ref = model._bqml_model_factory._create_model_ref( + session_with_bq_cmek._anonymous_dataset + ) + model_ref_full_name = ( + f"{model_ref.project}.{model_ref.dataset_id}.{model_ref.model_id}" + ) + new_model = model.to_gbq(model_ref_full_name) + assert new_model._bqml_model.model.encryption_configuration.kms_key_name == bq_cmek + + # Assert that model exists in BQ with intended encryption + model_bq = session_with_bq_cmek.bqclient.get_model(new_model._bqml_model.model_name) + assert model_bq.encryption_configuration.kms_key_name == bq_cmek diff 
--git a/tests/unit/session/test_clients.py b/tests/unit/session/test_clients.py index 30ba2f9091..74a721d6d8 100644 --- a/tests/unit/session/test_clients.py +++ b/tests/unit/session/test_clients.py @@ -93,6 +93,7 @@ def assert_clients_w_user_agent( provider: clients.ClientsProvider, expected_user_agent: str ): assert_constructed_w_user_agent(provider.bqclient, expected_user_agent) + assert_constructed_w_user_agent(provider.bqmlclient, expected_user_agent) assert_constructed_w_user_agent(provider.bqconnectionclient, expected_user_agent) assert_constructed_w_user_agent(provider.bqstoragereadclient, expected_user_agent) assert_constructed_w_user_agent(provider.cloudfunctionsclient, expected_user_agent) From 07d97f44c6ce01db476cc8ddc5c9f7e5d836e946 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Mon, 4 Mar 2024 20:13:01 +0000 Subject: [PATCH 06/16] fix unit tests --- tests/unit/ml/test_golden_sql.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/unit/ml/test_golden_sql.py b/tests/unit/ml/test_golden_sql.py index 017c96d46d..a3c9ed6e87 100644 --- a/tests/unit/ml/test_golden_sql.py +++ b/tests/unit/ml/test_golden_sql.py @@ -35,6 +35,7 @@ def mock_session(): mock_session._anonymous_dataset = bigquery.DatasetReference( TEMP_MODEL_ID.project, TEMP_MODEL_ID.dataset_id ) + mock_session._bq_kms_key_name = None query_job = mock.create_autospec(bigquery.QueryJob) type(query_job).destination = mock.PropertyMock( @@ -42,7 +43,7 @@ def mock_session(): mock_session._anonymous_dataset, TEMP_MODEL_ID.model_id ) ) - mock_session._start_query.return_value = (None, query_job) + mock_session._start_query_bqml.return_value = (None, query_job) return mock_session @@ -103,7 +104,7 @@ def test_linear_regression_default_fit( model._bqml_model_factory = bqml_model_factory model.fit(mock_X, mock_y) - mock_session._start_query.assert_called_once_with( + mock_session._start_query_bqml.assert_called_once_with( 'CREATE OR REPLACE MODEL 
`test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="normal_equation",\n fit_intercept=True,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n ls_init_learn_rate=0.1,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql' ) @@ -113,7 +114,7 @@ def test_linear_regression_params_fit(bqml_model_factory, mock_session, mock_X, model._bqml_model_factory = bqml_model_factory model.fit(mock_X, mock_y) - mock_session._start_query.assert_called_once_with( + mock_session._start_query_bqml.assert_called_once_with( 'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="normal_equation",\n fit_intercept=False,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n ls_init_learn_rate=0.1,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql' ) @@ -146,7 +147,7 @@ def test_logistic_regression_default_fit( model._bqml_model_factory = bqml_model_factory model.fit(mock_X, mock_y) - mock_session._start_query.assert_called_once_with( + mock_session._start_query_bqml.assert_called_once_with( 'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LOGISTIC_REG",\n data_split_method="NO_SPLIT",\n fit_intercept=True,\n auto_class_weights=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql' ) @@ -160,7 +161,7 @@ def test_logistic_regression_params_fit( model._bqml_model_factory = bqml_model_factory model.fit(mock_X, mock_y) - mock_session._start_query.assert_called_once_with( + mock_session._start_query_bqml.assert_called_once_with( 'CREATE OR REPLACE MODEL 
`test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LOGISTIC_REG",\n data_split_method="NO_SPLIT",\n fit_intercept=False,\n auto_class_weights=True,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql' ) From 950dd27efebcd5d2a323829baa027c3d328a9210 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Mon, 4 Mar 2024 20:32:05 +0000 Subject: [PATCH 07/16] fix mypy --- bigframes/session/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index a6090b9c65..005a19fd8b 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -1513,6 +1513,7 @@ def read_gbq_function( session=self, ) + @staticmethod def _start_query_with_client( bq_client: bigquery.Client, sql: str, From 927d8a220baf2fa31eacf100e500fcabff8c3f19 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Mon, 4 Mar 2024 23:04:09 +0000 Subject: [PATCH 08/16] skip cmek test for empty cmek --- tests/system/small/test_encryption.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/system/small/test_encryption.py b/tests/system/small/test_encryption.py index f0fd42e56e..3c53920c78 100644 --- a/tests/system/small/test_encryption.py +++ b/tests/system/small/test_encryption.py @@ -225,8 +225,10 @@ def test_read_pandas_large(bq_cmek, session_with_bq_cmek): def test_bqml(bq_cmek, session_with_bq_cmek, penguins_table_id): - model = bigframes.ml.linear_model.LinearRegression() + if not bq_cmek: + pytest.skip("no cmek set for testing") + model = bigframes.ml.linear_model.LinearRegression() df = session_with_bq_cmek.read_gbq(penguins_table_id).dropna() X_train = df[ [ From 827bef2a78b1d7f6ad03c949ff3852ff67493358 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Tue, 5 Mar 2024 23:57:18 +0000 Subject: [PATCH 09/16] move staticmethods to helper module --- bigframes/session/__init__.py | 50 ++----------------------------- bigframes/session/_io/bigquery.py | 50 
++++++++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 49 deletions(-) diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 005a19fd8b..f2944b2c6e 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -65,7 +65,6 @@ import bigframes._config.bigquery_options as bigquery_options import bigframes.constants as constants -from bigframes.core import log_adapter import bigframes.core as core import bigframes.core.blocks as blocks import bigframes.core.compile @@ -84,7 +83,6 @@ # Even though the ibis.backends.bigquery import is unused, it's needed # to register new and replacement ops with the Ibis BigQuery backend. -import third_party.bigframes_vendored.ibis.backends.bigquery # noqa import third_party.bigframes_vendored.ibis.expr.operations as vendored_ibis_ops import third_party.bigframes_vendored.pandas.io.gbq as third_party_pandas_gbq import third_party.bigframes_vendored.pandas.io.parquet as third_party_pandas_parquet @@ -1513,38 +1511,6 @@ def read_gbq_function( session=self, ) - @staticmethod - def _start_query_with_client( - bq_client: bigquery.Client, - sql: str, - job_config: Optional[bigquery.job.QueryJobConfig] = None, - max_results: Optional[int] = None, - ) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]: - """ - Starts query job and waits for results. - """ - job_config = Session._prepare_job_config(job_config) - api_methods = log_adapter.get_and_reset_api_methods() - job_config.labels = bigframes_io.create_job_configs_labels( - job_configs_labels=job_config.labels, api_methods=api_methods - ) - - try: - query_job = bq_client.query(sql, job_config=job_config) - except google.api_core.exceptions.Forbidden as ex: - if "Drive credentials" in ex.message: - ex.message += "\nCheck https://cloud.google.com/bigquery/docs/query-drive-data#Google_Drive_permissions." 
- raise - - opts = bigframes.options.display - if opts.progress_bar is not None and not query_job.configuration.dry_run: - results_iterator = formatting_helpers.wait_for_query_job( - query_job, max_results, opts.progress_bar - ) - else: - results_iterator = query_job.result(max_results=max_results) - return results_iterator, query_job - def _start_query( self, sql: str, @@ -1554,7 +1520,7 @@ def _start_query( """ Starts BigQuery query job and waits for results. """ - return Session._start_query_with_client( + return bigframes.session._io.bigquery.start_query_with_client( self.bqclient, sql, job_config, max_results ) @@ -1567,7 +1533,7 @@ def _start_query_bqml( """ Starts BigQuery ML query job and waits for results. """ - return Session._start_query_with_client( + return bigframes.session._io.bigquery.start_query_with_client( self.bqmlclient, sql, job_config, max_results ) @@ -1712,18 +1678,6 @@ def _start_generic_job(self, job: formatting_helpers.GenericJob): else: job.result() - @staticmethod - def _prepare_job_config( - job_config: Optional[bigquery.QueryJobConfig] = None, - ) -> bigquery.QueryJobConfig: - if job_config is None: - job_config = bigquery.QueryJobConfig() - if bigframes.options.compute.maximum_bytes_billed is not None: - job_config.maximum_bytes_billed = ( - bigframes.options.compute.maximum_bytes_billed - ) - return job_config - def connect(context: Optional[bigquery_options.BigQueryOptions] = None) -> Session: return Session(context) diff --git a/bigframes/session/_io/bigquery.py b/bigframes/session/_io/bigquery.py index 3695fc98e8..de7b10b85e 100644 --- a/bigframes/session/_io/bigquery.py +++ b/bigframes/session/_io/bigquery.py @@ -20,11 +20,17 @@ import itertools import textwrap import types -from typing import Dict, Iterable, Optional, Sequence, Union +from typing import Dict, Iterable, Optional, Sequence, Tuple, Union import uuid +import google.api_core.exceptions import google.cloud.bigquery as bigquery +import bigframes +from bigframes.core 
import log_adapter +import bigframes.formatting_helpers as formatting_helpers +import bigframes.session._io.bigquery as bigframes_io + IO_ORDERING_ID = "bqdf_row_nums" MAX_LABELS_COUNT = 64 TEMP_TABLE_PREFIX = "bqdf{date}_{random_id}" @@ -207,3 +213,45 @@ def format_option(key: str, value: Union[bool, str]) -> str: if isinstance(value, bool): return f"{key}=true" if value else f"{key}=false" return f"{key}={repr(value)}" + + +def _prepare_job_config( + job_config: Optional[bigquery.QueryJobConfig] = None, +) -> bigquery.QueryJobConfig: + if job_config is None: + job_config = bigquery.QueryJobConfig() + if bigframes.options.compute.maximum_bytes_billed is not None: + job_config.maximum_bytes_billed = bigframes.options.compute.maximum_bytes_billed + return job_config + + +def start_query_with_client( + bq_client: bigquery.Client, + sql: str, + job_config: Optional[bigquery.job.QueryJobConfig] = None, + max_results: Optional[int] = None, +) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]: + """ + Starts query job and waits for results. + """ + job_config = _prepare_job_config(job_config) + api_methods = log_adapter.get_and_reset_api_methods() + job_config.labels = bigframes_io.create_job_configs_labels( + job_configs_labels=job_config.labels, api_methods=api_methods + ) + + try: + query_job = bq_client.query(sql, job_config=job_config) + except google.api_core.exceptions.Forbidden as ex: + if "Drive credentials" in ex.message: + ex.message += "\nCheck https://cloud.google.com/bigquery/docs/query-drive-data#Google_Drive_permissions." 
+ raise + + opts = bigframes.options.display + if opts.progress_bar is not None and not query_job.configuration.dry_run: + results_iterator = formatting_helpers.wait_for_query_job( + query_job, max_results, opts.progress_bar + ) + else: + results_iterator = query_job.result(max_results=max_results) + return results_iterator, query_job From f1c8b00694420e5199f18484178f4206885a949e Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Thu, 7 Mar 2024 03:32:06 +0000 Subject: [PATCH 10/16] revert bqmlclient, pass cmek through call time job config --- bigframes/ml/core.py | 7 +-- bigframes/session/__init__.py | 71 ++++++++++++++++++++++----- bigframes/session/_io/bigquery.py | 13 +---- bigframes/session/clients.py | 18 ------- tests/system/small/test_encryption.py | 40 ++++++--------- 5 files changed, 76 insertions(+), 73 deletions(-) diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py index 9793208e41..90b26cd067 100644 --- a/bigframes/ml/core.py +++ b/bigframes/ml/core.py @@ -212,11 +212,8 @@ def principal_component_info(self) -> bpd.DataFrame: return self._session.read_gbq(sql) def copy(self, new_model_name: str, replace: bool = False) -> BqmlModel: - job_config = bigquery.job.CopyJobConfig( - destination_encryption_configuration=bigquery.EncryptionConfiguration( - kms_key_name=self._session._bq_kms_key_name - ) - ) + job_config = self._session._prepare_copy_job_config() + if replace: job_config.write_disposition = "WRITE_TRUNCATE" diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 712c8c9dc9..bd19404648 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -206,10 +206,6 @@ def __init__( def bqclient(self): return self._clients_provider.bqclient - @property - def bqmlclient(self): - return self._clients_provider.bqmlclient - @property def bqconnectionclient(self): return self._clients_provider.bqconnectionclient @@ -944,6 +940,8 @@ def _read_pandas_load_job( pandas_dataframe_copy.columns = 
pandas.Index(new_col_ids) pandas_dataframe_copy[ordering_col] = np.arange(pandas_dataframe_copy.shape[0]) + job_config = self._prepare_load_job_config() + # Specify the datetime dtypes, which is auto-detected as timestamp types. schema: list[bigquery.SchemaField] = [] for column, dtype in zip(pandas_dataframe.columns, pandas_dataframe.dtypes): @@ -951,12 +949,12 @@ def _read_pandas_load_job( schema.append( bigquery.SchemaField(column, bigquery.enums.SqlTypeNames.DATETIME) ) + job_config.schema = schema # Clustering probably not needed anyways as pandas tables are small cluster_cols = [ordering_col] - - job_config = bigquery.LoadJobConfig(schema=schema) job_config.clustering_fields = cluster_cols + job_config.labels = {"bigframes-api": api_name} load_table_destination = bigframes_io.random_table(self._anonymous_dataset) @@ -1076,7 +1074,7 @@ def read_csv( f"{constants.FEEDBACK_LINK}" ) - job_config = bigquery.LoadJobConfig() + job_config = self._prepare_load_job_config() job_config.create_disposition = bigquery.CreateDisposition.CREATE_IF_NEEDED job_config.source_format = bigquery.SourceFormat.CSV job_config.write_disposition = bigquery.WriteDisposition.WRITE_EMPTY @@ -1151,7 +1149,7 @@ def read_parquet( # job anyway. table = bigframes_io.random_table(self._anonymous_dataset) - job_config = bigquery.LoadJobConfig() + job_config = self._prepare_load_job_config() job_config.create_disposition = bigquery.CreateDisposition.CREATE_IF_NEEDED job_config.source_format = bigquery.SourceFormat.PARQUET job_config.write_disposition = bigquery.WriteDisposition.WRITE_EMPTY @@ -1196,7 +1194,7 @@ def read_json( "'lines' keyword is only valid when 'orient' is 'records'." 
) - job_config = bigquery.LoadJobConfig() + job_config = self._prepare_load_job_config() job_config.create_disposition = bigquery.CreateDisposition.CREATE_IF_NEEDED job_config.source_format = bigquery.SourceFormat.NEWLINE_DELIMITED_JSON job_config.write_disposition = bigquery.WriteDisposition.WRITE_EMPTY @@ -1520,6 +1518,50 @@ def read_gbq_function( session=self, ) + def _prepare_query_job_config( + self, + job_config: Optional[bigquery.QueryJobConfig] = None, + ) -> bigquery.QueryJobConfig: + if job_config is None: + job_config = bigquery.QueryJobConfig() + else: + # Create a copy so that we don't mutate the original config passed + job_config = bigquery.QueryJobConfig.from_api_repr(job_config.to_api_repr()) + + if bigframes.options.compute.maximum_bytes_billed is not None: + job_config.maximum_bytes_billed = ( + bigframes.options.compute.maximum_bytes_billed + ) + + if self._bq_kms_key_name: + job_config.destination_encryption_configuration = ( + bigquery.EncryptionConfiguration(kms_key_name=self._bq_kms_key_name) + ) + + return job_config + + def _prepare_load_job_config(self) -> bigquery.LoadJobConfig: + # Create a copy so that we don't mutate the original config passed + job_config = bigquery.LoadJobConfig() + + if self._bq_kms_key_name: + job_config.destination_encryption_configuration = ( + bigquery.EncryptionConfiguration(kms_key_name=self._bq_kms_key_name) + ) + + return job_config + + def _prepare_copy_job_config(self) -> bigquery.CopyJobConfig: + # Create a copy so that we don't mutate the original config passed + job_config = bigquery.CopyJobConfig() + + if self._bq_kms_key_name: + job_config.destination_encryption_configuration = ( + bigquery.EncryptionConfiguration(kms_key_name=self._bq_kms_key_name) + ) + + return job_config + def _start_query( self, sql: str, @@ -1529,6 +1571,7 @@ def _start_query( """ Starts BigQuery query job and waits for results. 
""" + job_config = self._prepare_query_job_config(job_config) return bigframes.session._io.bigquery.start_query_with_client( self.bqclient, sql, job_config, max_results ) @@ -1536,14 +1579,18 @@ def _start_query( def _start_query_bqml( self, sql: str, - job_config: Optional[bigquery.job.QueryJobConfig] = None, - max_results: Optional[int] = None, ) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]: """ Starts BigQuery ML query job and waits for results. """ + job_config = self._prepare_query_job_config() + + # BQML expects kms_key_name through OPTIONS and not through job config, + # so we must reset any encryption set in the job config + job_config.destination_encryption_configuration = None + return bigframes.session._io.bigquery.start_query_with_client( - self.bqmlclient, sql, job_config, max_results + self.bqclient, sql, job_config ) def _cache_with_cluster_cols( diff --git a/bigframes/session/_io/bigquery.py b/bigframes/session/_io/bigquery.py index de7b10b85e..67820bbbcb 100644 --- a/bigframes/session/_io/bigquery.py +++ b/bigframes/session/_io/bigquery.py @@ -215,26 +215,15 @@ def format_option(key: str, value: Union[bool, str]) -> str: return f"{key}={repr(value)}" -def _prepare_job_config( - job_config: Optional[bigquery.QueryJobConfig] = None, -) -> bigquery.QueryJobConfig: - if job_config is None: - job_config = bigquery.QueryJobConfig() - if bigframes.options.compute.maximum_bytes_billed is not None: - job_config.maximum_bytes_billed = bigframes.options.compute.maximum_bytes_billed - return job_config - - def start_query_with_client( bq_client: bigquery.Client, sql: str, - job_config: Optional[bigquery.job.QueryJobConfig] = None, + job_config: bigquery.job.QueryJobConfig, max_results: Optional[int] = None, ) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]: """ Starts query job and waits for results. 
""" - job_config = _prepare_job_config(job_config) api_methods = log_adapter.get_and_reset_api_methods() job_config.labels = bigframes_io.create_job_configs_labels( job_configs_labels=job_config.labels, api_methods=api_methods diff --git a/bigframes/session/clients.py b/bigframes/session/clients.py index 7e7456b07f..6eae18b2c8 100644 --- a/bigframes/session/clients.py +++ b/bigframes/session/clients.py @@ -137,27 +137,9 @@ def _create_bigquery_client(self): def bqclient(self): if not self._bqclient: self._bqclient = self._create_bigquery_client() - if self._bq_kms_key_name: - self._bqclient.default_query_job_config = bigquery.QueryJobConfig( - destination_encryption_configuration=bigquery.EncryptionConfiguration( - kms_key_name=self._bq_kms_key_name - ) - ) - self._bqclient.default_load_job_config = bigquery.LoadJobConfig( - destination_encryption_configuration=bigquery.EncryptionConfiguration( - kms_key_name=self._bq_kms_key_name - ) - ) return self._bqclient - @property - def bqmlclient(self): - if not self._bqmlclient: - self._bqmlclient = self._create_bigquery_client() - - return self._bqmlclient - @property def bqconnectionclient(self): if not self._bqconnectionclient: diff --git a/tests/system/small/test_encryption.py b/tests/system/small/test_encryption.py index 3c53920c78..75f39f6ff1 100644 --- a/tests/system/small/test_encryption.py +++ b/tests/system/small/test_encryption.py @@ -30,12 +30,10 @@ def bq_cmek() -> str: See https://cloud.google.com/bigquery/docs/customer-managed-encryption for steps. """ - # NOTE: Set a key here for local testing. - # We are keeping this empty by default, in which case the dependent tests - # are skipped. - # TODO(shobs): Automate the tests depending on this fixture by either creating - # a static key in the test project or automating the key creation during the test. - return "" + # NOTE: This key is manually set up through the cloud console + # TODO(shobs): Automate the the key creation during the test. 
This will + # require extra IAM privileges for the test runner. + return "projects/bigframes-dev-perf/locations/global/keyRings/shobsKeyRing/cryptoKeys/shobsKey" @pytest.fixture(scope="module") @@ -64,26 +62,11 @@ def _assert_bq_table_is_encrypted( assert table.encryption_configuration.kms_key_name == cmek -def test_session_default_configs(bq_cmek, session_with_bq_cmek): - if not bq_cmek: - pytest.skip("no cmek set for testing") - - # Session should have cmek set in the default query and load job configs - assert ( - session_with_bq_cmek.bqclient.default_query_job_config.destination_encryption_configuration.kms_key_name - == bq_cmek - ) - assert ( - session_with_bq_cmek.bqclient.default_load_job_config.destination_encryption_configuration.kms_key_name - == bq_cmek - ) - - def test_session_query_job(bq_cmek, session_with_bq_cmek): if not bq_cmek: pytest.skip("no cmek set for testing") - query_job = session_with_bq_cmek.bqclient.query("SELECT 123") + _, query_job = session_with_bq_cmek._start_query("SELECT 123") query_job.result() assert query_job.destination_encryption_configuration.kms_key_name.startswith( @@ -103,12 +86,17 @@ def test_session_load_job(bq_cmek, session_with_bq_cmek): load_table = bigframes.session._io.bigquery.random_table( session_with_bq_cmek._anonymous_dataset ) + + df = pandas.DataFrame({"col0": [1, 2, 3]}) + load_job_config = session_with_bq_cmek._prepare_load_job_config() + load_job_config.schema = [ + bigquery.SchemaField(df.columns[0], bigquery.enums.SqlTypeNames.INT64) + ] + load_job = session_with_bq_cmek.bqclient.load_table_from_dataframe( - pandas.DataFrame({"col0": [1, 2, 3]}), + df, load_table, - job_config=bigquery.LoadJobConfig( - schema=[bigquery.SchemaField("col0", bigquery.enums.SqlTypeNames.INT64)] - ), + job_config=load_job_config, ) load_job.result() From 073f9893362611845a3fa2535aee1361093113ad Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Thu, 7 Mar 2024 03:39:12 +0000 Subject: [PATCH 11/16] revert bqmlclient unit test 
--- tests/unit/session/test_clients.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/session/test_clients.py b/tests/unit/session/test_clients.py index 74a721d6d8..30ba2f9091 100644 --- a/tests/unit/session/test_clients.py +++ b/tests/unit/session/test_clients.py @@ -93,7 +93,6 @@ def assert_clients_w_user_agent( provider: clients.ClientsProvider, expected_user_agent: str ): assert_constructed_w_user_agent(provider.bqclient, expected_user_agent) - assert_constructed_w_user_agent(provider.bqmlclient, expected_user_agent) assert_constructed_w_user_agent(provider.bqconnectionclient, expected_user_agent) assert_constructed_w_user_agent(provider.bqstoragereadclient, expected_user_agent) assert_constructed_w_user_agent(provider.cloudfunctionsclient, expected_user_agent) From 79a4b73db6678642df60acd7182635f4857bb5f0 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Thu, 7 Mar 2024 06:06:34 +0000 Subject: [PATCH 12/16] fix mypy failure --- bigframes/session/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 243ace37a5..f64090c013 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -1539,7 +1539,10 @@ def _prepare_query_job_config( job_config = bigquery.QueryJobConfig() else: # Create a copy so that we don't mutate the original config passed - job_config = bigquery.QueryJobConfig.from_api_repr(job_config.to_api_repr()) + job_config = typing.cast( + bigquery.QueryJobConfig, + bigquery.QueryJobConfig.from_api_repr(job_config.to_api_repr()), + ) if bigframes.options.compute.maximum_bytes_billed is not None: job_config.maximum_bytes_billed = ( From 313bb12a43e5a5432fa2538fc347e11a2c7016f9 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Thu, 7 Mar 2024 17:13:05 +0000 Subject: [PATCH 13/16] use better named key, disable use_query_cache in test --- tests/system/small/test_encryption.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 
deletions(-) diff --git a/tests/system/small/test_encryption.py b/tests/system/small/test_encryption.py index 75f39f6ff1..0ce9d881fd 100644 --- a/tests/system/small/test_encryption.py +++ b/tests/system/small/test_encryption.py @@ -33,7 +33,7 @@ def bq_cmek() -> str: # NOTE: This key is manually set up through the cloud console # TODO(shobs): Automate the the key creation during the test. This will # require extra IAM privileges for the test runner. - return "projects/bigframes-dev-perf/locations/global/keyRings/shobsKeyRing/cryptoKeys/shobsKey" + return "projects/bigframes-dev-perf/locations/us/keyRings/bigframesKeyRing/cryptoKeys/bigframesKey" @pytest.fixture(scope="module") @@ -66,7 +66,9 @@ def test_session_query_job(bq_cmek, session_with_bq_cmek): if not bq_cmek: pytest.skip("no cmek set for testing") - _, query_job = session_with_bq_cmek._start_query("SELECT 123") + _, query_job = session_with_bq_cmek._start_query( + "SELECT 123", job_config=bigquery.QueryJobConfig(use_query_cache=False) + ) query_job.result() assert query_job.destination_encryption_configuration.kms_key_name.startswith( From 4d4bb10de7efa9470e97f145256c3804850bb123 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Fri, 8 Mar 2024 00:33:15 +0000 Subject: [PATCH 14/16] rename bqml create model internal method --- bigframes/ml/core.py | 4 ++-- bigframes/session/__init__.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py index 90b26cd067..24997708fb 100644 --- a/bigframes/ml/core.py +++ b/bigframes/ml/core.py @@ -237,7 +237,7 @@ def register(self, vertex_ai_model_id: Optional[str] = None) -> BqmlModel: options={"vertex_ai_model_id": vertex_ai_model_id} ) # Register the model and wait it to finish - self._session._start_query_bqml(sql) + self._session._start_query_create_model(sql) self._model = self._session.bqclient.get_model(self.model_name) return self @@ -256,7 +256,7 @@ def _create_model_ref( def 
_create_model_with_sql(self, session: bigframes.Session, sql: str) -> BqmlModel: # fit the model, synchronously - _, job = session._start_query_bqml(sql) + _, job = session._start_query_create_model(sql) # real model path in the session specific hidden dataset and table prefix model_name_full = f"{job.destination.project}.{job.destination.dataset_id}.{job.destination.table_id}" diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index f64090c013..ff08eed764 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -1592,17 +1592,18 @@ def _start_query( self.bqclient, sql, job_config, max_results ) - def _start_query_bqml( + def _start_query_create_model( self, sql: str, ) -> Tuple[bigquery.table.RowIterator, bigquery.QueryJob]: """ - Starts BigQuery ML query job and waits for results. + Starts BigQuery ML CREATE MODEL query job and waits for results. """ job_config = self._prepare_query_job_config() # BQML expects kms_key_name through OPTIONS and not through job config, # so we must reset any encryption set in the job config + # https://cloud.google.com/bigquery/docs/customer-managed-encryption#encrypt-model job_config.destination_encryption_configuration = None return bigframes.session._io.bigquery.start_query_with_client( From 0eff3a6670f527736d1eb4fc3ca6b90bd60c8d68 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Fri, 8 Mar 2024 01:34:45 +0000 Subject: [PATCH 15/16] fix renamed methods's reference in unit tests --- tests/unit/ml/test_golden_sql.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/unit/ml/test_golden_sql.py b/tests/unit/ml/test_golden_sql.py index a3c9ed6e87..25e12d87bf 100644 --- a/tests/unit/ml/test_golden_sql.py +++ b/tests/unit/ml/test_golden_sql.py @@ -43,7 +43,7 @@ def mock_session(): mock_session._anonymous_dataset, TEMP_MODEL_ID.model_id ) ) - mock_session._start_query_bqml.return_value = (None, query_job) + 
mock_session._start_query_create_model.return_value = (None, query_job) return mock_session @@ -104,7 +104,7 @@ def test_linear_regression_default_fit( model._bqml_model_factory = bqml_model_factory model.fit(mock_X, mock_y) - mock_session._start_query_bqml.assert_called_once_with( + mock_session._start_query_create_model.assert_called_once_with( 'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="normal_equation",\n fit_intercept=True,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n ls_init_learn_rate=0.1,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql' ) @@ -114,7 +114,7 @@ def test_linear_regression_params_fit(bqml_model_factory, mock_session, mock_X, model._bqml_model_factory = bqml_model_factory model.fit(mock_X, mock_y) - mock_session._start_query_bqml.assert_called_once_with( + mock_session._start_query_create_model.assert_called_once_with( 'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="normal_equation",\n fit_intercept=False,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n ls_init_learn_rate=0.1,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql' ) @@ -147,7 +147,7 @@ def test_logistic_regression_default_fit( model._bqml_model_factory = bqml_model_factory model.fit(mock_X, mock_y) - mock_session._start_query_bqml.assert_called_once_with( + mock_session._start_query_create_model.assert_called_once_with( 'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LOGISTIC_REG",\n data_split_method="NO_SPLIT",\n fit_intercept=True,\n 
auto_class_weights=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql' ) @@ -161,7 +161,7 @@ def test_logistic_regression_params_fit( model._bqml_model_factory = bqml_model_factory model.fit(mock_X, mock_y) - mock_session._start_query_bqml.assert_called_once_with( + mock_session._start_query_create_model.assert_called_once_with( 'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LOGISTIC_REG",\n data_split_method="NO_SPLIT",\n fit_intercept=False,\n auto_class_weights=True,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql' ) From 185edd9f1c5d22dc536374886fbebdc9ed4f0b90 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Fri, 8 Mar 2024 03:02:41 +0000 Subject: [PATCH 16/16] remove stray bqmlclient variable --- bigframes/session/clients.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bigframes/session/clients.py b/bigframes/session/clients.py index 6eae18b2c8..7574aa4454 100644 --- a/bigframes/session/clients.py +++ b/bigframes/session/clients.py @@ -103,7 +103,6 @@ def __init__( # cloud clients initialized for lazy load self._bqclient = None - self._bqmlclient = None self._bqconnectionclient = None self._bqstoragereadclient = None self._cloudfunctionsclient = None