diff --git a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
index b681b92d033b7..63c625be87a25 100644
--- a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
+++ b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
@@ -313,15 +313,17 @@ def execute(self, context: Context):
             self.source_objects if isinstance(self.source_objects, list) else [self.source_objects]
         )
         source_uris = [f"gs://{self.bucket}/{source_object}" for source_object in self.source_objects]
-        if not self.schema_fields:
+
+        if not self.schema_fields and self.schema_object and self.source_format != "DATASTORE_BACKUP":
             gcs_hook = GCSHook(
                 gcp_conn_id=self.gcp_conn_id,
                 delegate_to=self.delegate_to,
                 impersonation_chain=self.impersonation_chain,
             )
-            if self.schema_object and self.source_format != "DATASTORE_BACKUP":
-                schema_fields = json.loads(gcs_hook.download(self.bucket, self.schema_object).decode("utf-8"))
-                self.log.info("Autodetected fields from schema object: %s", schema_fields)
+            self.schema_fields = json.loads(
+                gcs_hook.download(self.schema_object_bucket, self.schema_object).decode("utf-8")
+            )
+            self.log.info("Autodetected fields from schema object: %s", self.schema_fields)

         if self.external_table:
             self.log.info("Creating a new BigQuery table for storing data...")
diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
index f4b5f59f82d34..1a9356134c239 100644
--- a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
+++ b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations

+import json
 import unittest
 from unittest import mock
 from unittest.mock import MagicMock, call
@@ -51,6 +52,8 @@
     {"name": "id", "type": "INTEGER", "mode": "NULLABLE"},
     {"name": "name", "type": "STRING", "mode": "NULLABLE"},
 ]
+SCHEMA_BUCKET = "test-schema-bucket"
+SCHEMA_OBJECT = "test/schema/schema.json"
 TEST_SOURCE_OBJECTS = ["test/objects/test.csv"]
 TEST_SOURCE_OBJECTS_AS_STRING = "test/objects/test.csv"
 LABELS = {"k1": "v1"}
@@ -675,6 +678,117 @@ def test_source_objs_as_string_without_external_table_should_execute_successfull
         hook.return_value.insert_job.assert_has_calls(calls)

+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSHook")
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+    def test_schema_obj_external_table_should_execute_successfully(self, bq_hook, gcs_hook):
+        bq_hook.return_value.insert_job.side_effect = [
+            MagicMock(job_id=pytest.real_job_id, error_result=False),
+            pytest.real_job_id,
+        ]
+        bq_hook.return_value.generate_job_id.return_value = pytest.real_job_id
+        bq_hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+        gcs_hook.return_value.download.return_value = bytes(json.dumps(SCHEMA_FIELDS), "utf-8")
+        operator = GCSToBigQueryOperator(
+            task_id=TASK_ID,
+            bucket=TEST_BUCKET,
+            source_objects=TEST_SOURCE_OBJECTS,
+            schema_object_bucket=SCHEMA_BUCKET,
+            schema_object=SCHEMA_OBJECT,
+            write_disposition=WRITE_DISPOSITION,
+            destination_project_dataset_table=TEST_EXPLICIT_DEST,
+            external_table=True,
+        )
+
+        operator.execute(context=MagicMock())
+
+        bq_hook.return_value.create_empty_table.assert_called_once_with(
+            table_resource={
+                "tableReference": {"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+                "labels": None,
+                "description": None,
+                "externalDataConfiguration": {
"externalDataConfiguration": { + "source_uris": [f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"], + "source_format": "CSV", + "maxBadRecords": 0, + "autodetect": True, + "compression": "NONE", + "csvOptions": { + "fieldDelimeter": ",", + "skipLeadingRows": None, + "quote": None, + "allowQuotedNewlines": False, + "allowJaggedRows": False, + }, + }, + "location": None, + "encryptionConfiguration": None, + "schema": {"fields": SCHEMA_FIELDS}, + } + ) + gcs_hook.return_value.download.assert_called_once_with(SCHEMA_BUCKET, SCHEMA_OBJECT) + + @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSHook") + @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook") + def test_schema_obj_without_external_table_should_execute_successfully(self, bq_hook, gcs_hook): + bq_hook.return_value.insert_job.side_effect = [ + MagicMock(job_id=pytest.real_job_id, error_result=False), + pytest.real_job_id, + ] + bq_hook.return_value.generate_job_id.return_value = pytest.real_job_id + bq_hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE) + gcs_hook.return_value.download.return_value = bytes(json.dumps(SCHEMA_FIELDS), "utf-8") + + operator = GCSToBigQueryOperator( + task_id=TASK_ID, + bucket=TEST_BUCKET, + source_objects=TEST_SOURCE_OBJECTS, + schema_object_bucket=SCHEMA_BUCKET, + schema_object=SCHEMA_OBJECT, + destination_project_dataset_table=TEST_EXPLICIT_DEST, + write_disposition=WRITE_DISPOSITION, + external_table=False, + ) + + operator.execute(context=MagicMock()) + + calls = [ + call( + configuration={ + "load": dict( + autodetect=True, + createDisposition="CREATE_IF_NEEDED", + destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE}, + destinationTableProperties={ + "description": None, + "labels": None, + }, + sourceFormat="CSV", + skipLeadingRows=None, + sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"], + writeDisposition=WRITE_DISPOSITION, + ignoreUnknownValues=False, + allowQuotedNewlines=False, + encoding="UTF-8", + schema={"fields": SCHEMA_FIELDS}, + allowJaggedRows=False, + fieldDelimiter=",", + maxBadRecords=0, + quote=None, + schemaUpdateOptions=(), + ), + }, + project_id=bq_hook.return_value.project_id, + location=None, + job_id=pytest.real_job_id, + timeout=None, + retry=DEFAULT_RETRY, + nowait=True, + ), + ] + + bq_hook.return_value.insert_job.assert_has_calls(calls) + gcs_hook.return_value.download.assert_called_once_with(SCHEMA_BUCKET, SCHEMA_OBJECT) + @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook") def test_all_fields_should_be_present(self, hook): hook.return_value.insert_job.side_effect = [