Fix: re-enable use of parameters in gcs_to_bq which had been disabled #27961

Merged 4 commits on Dec 4, 2022
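For context, here is a minimal usage sketch of GCSToBigQueryOperator with the load-job parameters this PR re-enables. All values are illustrative placeholders (the bucket, object, table, and KMS key names are not from the PR); the parameter names themselves are the ones exercised by the tests below.

from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator

# Illustrative values only; each keyword below maps onto a "load" job field
# that the operator now forwards again (fieldDelimiter, maxBadRecords, quote,
# schemaUpdateOptions, allowJaggedRows, destinationEncryptionConfiguration,
# clustering, timePartitioning).
load_csv = GCSToBigQueryOperator(
    task_id="gcs_to_bq_example",
    bucket="example-bucket",  # hypothetical bucket
    source_objects=["data/example.csv"],  # hypothetical object
    destination_project_dataset_table="my-project.my_dataset.my_table",
    write_disposition="WRITE_TRUNCATE",
    external_table=False,
    field_delimiter=";",
    max_bad_records=13,
    quote_character="|",
    schema_update_options=["ALLOW_FIELD_ADDITION"],
    allow_jagged_rows=True,
    encryption_configuration={"kmsKeyName": "projects/p/locations/l/keyRings/r/cryptoKeys/k"},
    cluster_fields=["field_1", "field_2"],
    time_partitioning={"type": "DAY"},
)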
26 changes: 25 additions & 1 deletion airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
@@ -27,7 +27,11 @@

from airflow import AirflowException
from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob
from airflow.providers.google.cloud.hooks.bigquery import (
BigQueryHook,
BigQueryJob,
_cleanse_time_partitioning,
)
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger
@@ -390,8 +394,28 @@ def execute(self, context: Context):
"ignoreUnknownValues": self.ignore_unknown_values,
"allowQuotedNewlines": self.allow_quoted_newlines,
"encoding": self.encoding,
"allowJaggedRows": self.allow_jagged_rows,
"fieldDelimiter": self.field_delimiter,
"maxBadRecords": self.max_bad_records,
"quote": self.quote_character,
"schemaUpdateOptions": self.schema_update_options,
},
}
if self.cluster_fields:
self.configuration["load"].update({"clustering": {"fields": self.cluster_fields}})
time_partitioning = _cleanse_time_partitioning(
self.destination_project_dataset_table, self.time_partitioning
)
if time_partitioning:
self.configuration["load"].update({"timePartitioning": time_partitioning})
# fields that should only be set if defined
set_if_def = {
"quote": self.quote_character,
"destinationEncryptionConfiguration": self.encryption_configuration,
}
for k, v in set_if_def.items():
if v:
self.configuration["load"][k] = v
self.configuration = self._check_schema_fields(self.configuration)
try:
self.log.info("Executing: %s", self.configuration)
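The newly imported _cleanse_time_partitioning helper is what the two date-partitioning tests below exercise: an explicit time_partitioning dict is passed through, and a "$"-suffixed destination table (e.g. "table$20221123") implies day partitioning on its own. A sketch of that observable behavior, assuming the helper in airflow.providers.google.cloud.hooks.bigquery works as the tests imply:

def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in):
    # Sketch of the hook helper's observable behavior, not the canonical source:
    # a "$" in the destination table implies {"type": "DAY"}, and any explicit
    # time_partitioning settings are layered on top of that.
    if time_partitioning_in is None:
        time_partitioning_in = {}
    time_partitioning_out = {}
    if destination_dataset_table and "$" in destination_dataset_table:
        time_partitioning_out["type"] = "DAY"
    time_partitioning_out.update(time_partitioning_in)
    return time_partitioning_out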
228 changes: 228 additions & 0 deletions tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
@@ -163,6 +163,11 @@ def test_max_value_without_external_table_should_execute_successfully(self, hook
allowQuotedNewlines=False,
encoding="UTF-8",
schema={"fields": SCHEMA_FIELDS},
allowJaggedRows=False,
fieldDelimiter=",",
maxBadRecords=0,
quote=None,
schemaUpdateOptions=(),
),
},
project_id=hook.return_value.project_id,
@@ -226,6 +231,11 @@ def test_max_value_should_throw_ex_when_query_returns_no_rows(self, hook):
allowQuotedNewlines=False,
encoding="UTF-8",
schema={"fields": SCHEMA_FIELDS},
allowJaggedRows=False,
fieldDelimiter=",",
maxBadRecords=0,
quote=None,
schemaUpdateOptions=(),
),
},
project_id=hook.return_value.project_id,
@@ -335,6 +345,11 @@ def test_labels_without_external_table_should_execute_successfully(self, hook):
allowQuotedNewlines=False,
encoding="UTF-8",
schema={"fields": SCHEMA_FIELDS},
allowJaggedRows=False,
fieldDelimiter=",",
maxBadRecords=0,
quote=None,
schemaUpdateOptions=(),
),
},
project_id=hook.return_value.project_id,
@@ -434,6 +449,11 @@ def test_description_without_external_table_should_execute_successfully(self, ho
allowQuotedNewlines=False,
encoding="UTF-8",
schema={"fields": SCHEMA_FIELDS},
allowJaggedRows=False,
fieldDelimiter=",",
maxBadRecords=0,
quote=None,
schemaUpdateOptions=(),
),
},
project_id=hook.return_value.project_id,
@@ -535,6 +555,11 @@ def test_source_objs_as_list_without_external_table_should_execute_successfully(
allowQuotedNewlines=False,
encoding="UTF-8",
schema={"fields": SCHEMA_FIELDS},
allowJaggedRows=False,
fieldDelimiter=",",
maxBadRecords=0,
quote=None,
schemaUpdateOptions=(),
),
},
project_id=hook.return_value.project_id,
@@ -632,6 +657,194 @@ def test_source_objs_as_string_without_external_table_should_execute_successfull
allowQuotedNewlines=False,
encoding="UTF-8",
schema={"fields": SCHEMA_FIELDS},
allowJaggedRows=False,
fieldDelimiter=",",
maxBadRecords=0,
quote=None,
schemaUpdateOptions=(),
),
},
project_id=hook.return_value.project_id,
location=None,
job_id=pytest.real_job_id,
timeout=None,
retry=DEFAULT_RETRY,
nowait=True,
),
]

hook.return_value.insert_job.assert_has_calls(calls)

@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
def test_all_fields_should_be_present(self, hook):
hook.return_value.insert_job.side_effect = [
MagicMock(job_id=pytest.real_job_id, error_result=False),
pytest.real_job_id,
]
hook.return_value.generate_job_id.return_value = pytest.real_job_id
hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
operator = GCSToBigQueryOperator(
task_id=TASK_ID,
bucket=TEST_BUCKET,
source_objects=TEST_SOURCE_OBJECTS,
schema_fields=SCHEMA_FIELDS,
destination_project_dataset_table=TEST_EXPLICIT_DEST,
write_disposition=WRITE_DISPOSITION,
external_table=False,
field_delimiter=";",
max_bad_records=13,
quote_character="|",
schema_update_options={"foo": "bar"},
allow_jagged_rows=True,
encryption_configuration={"bar": "baz"},
cluster_fields=["field_1", "field_2"],
)

operator.execute(context=MagicMock())

calls = [
call(
configuration={
"load": dict(
autodetect=True,
createDisposition="CREATE_IF_NEEDED",
destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
destinationTableProperties={
"description": None,
"labels": None,
},
sourceFormat="CSV",
skipLeadingRows=None,
sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
writeDisposition=WRITE_DISPOSITION,
ignoreUnknownValues=False,
allowQuotedNewlines=False,
encoding="UTF-8",
schema={"fields": SCHEMA_FIELDS},
allowJaggedRows=True,
fieldDelimiter=";",
maxBadRecords=13,
quote="|",
schemaUpdateOptions={"foo": "bar"},
destinationEncryptionConfiguration={"bar": "baz"},
clustering={"fields": ["field_1", "field_2"]},
),
},
project_id=hook.return_value.project_id,
location=None,
job_id=pytest.real_job_id,
timeout=None,
retry=DEFAULT_RETRY,
nowait=True,
),
]

hook.return_value.insert_job.assert_has_calls(calls)

@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
def test_date_partitioned_explicit_setting_should_be_found(self, hook):
hook.return_value.insert_job.side_effect = [
MagicMock(job_id=pytest.real_job_id, error_result=False),
pytest.real_job_id,
]
hook.return_value.generate_job_id.return_value = pytest.real_job_id
hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
operator = GCSToBigQueryOperator(
task_id=TASK_ID,
bucket=TEST_BUCKET,
source_objects=TEST_SOURCE_OBJECTS,
schema_fields=SCHEMA_FIELDS,
destination_project_dataset_table=TEST_EXPLICIT_DEST,
write_disposition=WRITE_DISPOSITION,
external_table=False,
time_partitioning={"type": "DAY"},
)

operator.execute(context=MagicMock())

calls = [
call(
configuration={
"load": dict(
autodetect=True,
createDisposition="CREATE_IF_NEEDED",
destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
destinationTableProperties={
"description": None,
"labels": None,
},
sourceFormat="CSV",
skipLeadingRows=None,
sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
writeDisposition=WRITE_DISPOSITION,
ignoreUnknownValues=False,
allowQuotedNewlines=False,
encoding="UTF-8",
schema={"fields": SCHEMA_FIELDS},
allowJaggedRows=False,
fieldDelimiter=",",
maxBadRecords=0,
quote=None,
schemaUpdateOptions=(),
timePartitioning={"type": "DAY"},
),
},
project_id=hook.return_value.project_id,
location=None,
job_id=pytest.real_job_id,
timeout=None,
retry=DEFAULT_RETRY,
nowait=True,
),
]

hook.return_value.insert_job.assert_has_calls(calls)

@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
def test_date_partitioned_implied_in_table_name_should_be_found(self, hook):
hook.return_value.insert_job.side_effect = [
MagicMock(job_id=pytest.real_job_id, error_result=False),
pytest.real_job_id,
]
hook.return_value.generate_job_id.return_value = pytest.real_job_id
hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
operator = GCSToBigQueryOperator(
task_id=TASK_ID,
bucket=TEST_BUCKET,
source_objects=TEST_SOURCE_OBJECTS,
schema_fields=SCHEMA_FIELDS,
destination_project_dataset_table=TEST_EXPLICIT_DEST + "$20221123",
write_disposition=WRITE_DISPOSITION,
external_table=False,
)

operator.execute(context=MagicMock())

calls = [
call(
configuration={
"load": dict(
autodetect=True,
createDisposition="CREATE_IF_NEEDED",
destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
destinationTableProperties={
"description": None,
"labels": None,
},
sourceFormat="CSV",
skipLeadingRows=None,
sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
writeDisposition=WRITE_DISPOSITION,
ignoreUnknownValues=False,
allowQuotedNewlines=False,
encoding="UTF-8",
schema={"fields": SCHEMA_FIELDS},
allowJaggedRows=False,
fieldDelimiter=",",
maxBadRecords=0,
quote=None,
schemaUpdateOptions=(),
timePartitioning={"type": "DAY"},
),
},
project_id=hook.return_value.project_id,
@@ -830,6 +1043,11 @@ def test_schema_fields_scanner_without_external_table_should_execute_successfull
allowQuotedNewlines=False,
encoding="UTF-8",
schema={"fields": SCHEMA_FIELDS},
allowJaggedRows=False,
fieldDelimiter=",",
maxBadRecords=0,
quote=None,
schemaUpdateOptions=(),
),
},
project_id=bq_hook.return_value.project_id,
@@ -1023,6 +1241,11 @@ def test_schema_fields_integer_scanner_without_external_table_should_execute_suc
ignoreUnknownValues=False,
allowQuotedNewlines=False,
encoding="UTF-8",
allowJaggedRows=False,
fieldDelimiter=",",
maxBadRecords=0,
quote=None,
schemaUpdateOptions=(),
),
},
project_id=bq_hook.return_value.project_id,
@@ -1087,6 +1310,11 @@ def test_schema_fields_without_external_table_should_execute_successfully(self,
allowQuotedNewlines=False,
encoding="UTF-8",
schema={"fields": SCHEMA_FIELDS_INT},
allowJaggedRows=False,
fieldDelimiter=",",
maxBadRecords=0,
quote=None,
schemaUpdateOptions=(),
),
},
project_id=hook.return_value.project_id,
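One way to run just the tests added in this PR, assuming a local Airflow development environment with pytest installed:

pytest tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py -k "all_fields_should_be_present or date_partitioned"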