Commit 9ffa97f

Revert changes for determining schema fields in _check_schema_fields()
VladaZakharova committed Dec 20, 2022
1 parent 7e5bb81 commit 9ffa97f
Showing 3 changed files with 135 additions and 583 deletions.
93 changes: 27 additions & 66 deletions airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
@@ -21,7 +21,7 @@
 import json
 from typing import TYPE_CHECKING, Any, Sequence
 
-from google.api_core.exceptions import Conflict
+from google.api_core.exceptions import BadRequest, Conflict
 from google.api_core.retry import Retry
 from google.cloud.bigquery import (
     DEFAULT_RETRY,
@@ -247,13 +247,13 @@ def __init__(
         # BQ config
         self.destination_project_dataset_table = destination_project_dataset_table
         self.schema_fields = schema_fields
-        if source_format not in ALLOWED_FORMATS:
+        if source_format.upper() not in ALLOWED_FORMATS:
             raise ValueError(
                 f"{source_format} is not a valid source format. "
                 f"Please use one of the following types: {ALLOWED_FORMATS}."
             )
         else:
-            self.source_format = source_format
+            self.source_format = source_format.upper()
         self.compression = compression
         self.create_disposition = create_disposition
         self.skip_leading_rows = skip_leading_rows
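
The .upper() normalization makes the source-format check case-insensitive, so a lowercase value such as "csv" now passes validation instead of raising ValueError. A minimal usage sketch (task id, bucket, object, and table names are hypothetical):

    from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator

    load_csv = GCSToBigQueryOperator(
        task_id="gcs_to_bq_example",
        bucket="example-bucket",
        source_objects=["data/example.csv"],
        destination_project_dataset_table="example-project.example_dataset.example_table",
        source_format="csv",  # normalized to "CSV" in __init__
        write_disposition="WRITE_TRUNCATE",
    )
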
@@ -336,20 +336,23 @@ def execute(self, context: Context):
         )
         self.source_uris = [f"gs://{self.bucket}/{source_object}" for source_object in self.source_objects]
 
-        if not self.schema_fields and self.schema_object and self.source_format != "DATASTORE_BACKUP":
-            gcs_hook = GCSHook(
-                gcp_conn_id=self.gcp_conn_id,
-                delegate_to=self.delegate_to,
-                impersonation_chain=self.impersonation_chain,
-            )
-            self.schema_fields = json.loads(
-                gcs_hook.download(self.schema_object_bucket, self.schema_object).decode("utf-8")
-            )
-            self.log.info("Autodetected fields from schema object: %s", self.schema_fields)
+        if not self.schema_fields:
+            if not self.schema_object and not self.autodetect:
+                raise AirflowException(
+                    "Table schema was not found. Neither schema object nor schema fields were specified"
+                )
+            if self.schema_object and self.source_format != "DATASTORE_BACKUP":
+                gcs_hook = GCSHook(
+                    gcp_conn_id=self.gcp_conn_id,
+                    delegate_to=self.delegate_to,
+                    impersonation_chain=self.impersonation_chain,
+                )
+                self.schema_fields = json.loads(
+                    gcs_hook.download(self.schema_object_bucket, self.schema_object).decode("utf-8")
+                )
+                self.log.info("Loaded fields from schema object: %s", self.schema_fields)
+            else:
+                self.schema_fields = None
 
         if self.external_table:
             self.log.info("Creating a new BigQuery table for storing data...")
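
The restored block resolves the table schema with an explicit precedence: user-supplied schema_fields win, then a schema object downloaded from GCS, then BigQuery autodetection; if none is available, the operator fails fast. A standalone sketch of that precedence (the download callable stands in for GCSHook.download and is an assumption, not the provider's API):

    import json

    def resolve_schema_fields(schema_fields, schema_object, source_format, autodetect, download):
        if schema_fields:
            return schema_fields  # explicit fields always win
        if not schema_object and not autodetect:
            raise ValueError("Table schema was not found: no schema fields, schema object, or autodetect")
        if schema_object and source_format != "DATASTORE_BACKUP":
            return json.loads(download(schema_object).decode("utf-8"))  # schema JSON stored in GCS
        return None  # defer to BigQuery autodetection
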
@@ -469,8 +472,17 @@ def _find_max_value_in_column(self):
                 "schemaUpdateOptions": [],
             }
         }
-        job_id = hook.insert_job(configuration=self.configuration, project_id=hook.project_id)
-        rows = list(hook.get_job(job_id=job_id, location=self.location).result())
+        try:
+            job_id = hook.insert_job(configuration=self.configuration, project_id=hook.project_id)
+            rows = list(hook.get_job(job_id=job_id, location=self.location).result())
+        except BadRequest as e:
+            if "Unrecognized name:" in e.message:
+                raise AirflowException(
+                    f"Could not determine MAX value in column {self.max_id_key} "
+                    f"since the default value of 'string_field_n' was set by BQ"
+                )
+            else:
+                raise AirflowException(e.message)
         if rows:
             for row in rows:
                 max_id = row[0] if row[0] else 0
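
When BigQuery autodetects a schema from a headerless CSV, it assigns default column names such as string_field_0 and string_field_1, so a SELECT MAX(...) on the user-supplied max_id_key fails with "Unrecognized name: ...". A rough reproduction with the BigQuery client (project, dataset, table, and column names are hypothetical):

    from google.api_core.exceptions import BadRequest
    from google.cloud import bigquery

    client = bigquery.Client()
    try:
        query = "SELECT MAX(id) FROM `example-project.example_dataset.autodetected_table`"
        rows = list(client.query(query).result())
    except BadRequest as e:
        # For a table whose columns were auto-named string_field_0, string_field_1, ...,
        # e.message contains "Unrecognized name: id"; the operator now surfaces this clearly.
        print(e.message)
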
@@ -484,53 +496,6 @@ def _find_max_value_in_column(self):
         else:
             raise RuntimeError(f"The {select_command} returned no rows!")
 
-    def _check_schema_fields(self, table_resource):
-        """
-        Helper method to detect schema fields if they were not specified by the user and autodetect=True.
-        If source_objects were passed, the method reads the second row of the CSV file. If it contains
-        at least one digit, table_resource is returned without changes so that BigQuery can determine
-        schema_fields in the next step.
-        If it contains only characters, the first row of fields is used to construct a schema_fields
-        argument with type 'STRING'. table_resource is updated with the new schema_fields key and returned to the operator.
-        :param table_resource: Configuration or table_resource dictionary
-        :return: table_resource: Updated table_resource dict with schema_fields
-        """
-        if not self.schema_fields:
-            for source_object in self.source_objects:
-                if self.source_format == "CSV":
-                    gcs_hook = GCSHook(
-                        gcp_conn_id=self.gcp_conn_id,
-                        delegate_to=self.delegate_to,
-                        impersonation_chain=self.impersonation_chain,
-                    )
-                    blob = gcs_hook.download(
-                        bucket_name=self.schema_object_bucket,
-                        object_name=source_object,
-                    )
-                    fields, values = [item.split(",") for item in blob.decode("utf-8").splitlines()][:2]
-                    self.log.info("fields: %s", fields)
-                    import re
-
-                    if any(re.match(r"[\d\-\\.]+$", value) for value in values):
-                        self.log.info("table_resource: %s", table_resource)
-                        return table_resource
-                    else:
-                        schema_fields = []
-                        for field in fields:
-                            schema_fields.append({"name": field, "type": "STRING", "mode": "NULLABLE"})
-                        self.schema_fields = schema_fields
-                        if self.external_table:
-                            table_resource["externalDataConfiguration"]["csvOptions"]["skipLeadingRows"] = 1
-                        elif not self.external_table:
-                            table_resource["load"]["skipLeadingRows"] = 1
-        else:
-            return table_resource
-        if self.external_table:
-            table_resource["schema"] = {"fields": self.schema_fields}
-        elif not self.external_table:
-            table_resource["load"]["schema"] = {"fields": self.schema_fields}
-        return table_resource
-
     def _create_empty_table(self):
         project_id, dataset_id, table_id = self.hook.split_tablename(
             table_input=self.destination_project_dataset_table,
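
The _check_schema_fields heuristic removed above keyed the STRING-versus-autodetect decision off a regex match against the second CSV row only, so a single numeric-looking value anywhere in that row deferred the whole file to autodetect. A small probe with hypothetical data illustrates this (the commit message itself does not state the rationale for the revert):

    import re

    second_row = ["7", "alice"]  # hypothetical CSV row: one numeric value, one text value
    # Any numeric-looking value defers the entire file to BigQuery autodetect,
    # even though "alice" plainly belongs in a STRING column.
    print(any(re.match(r"[\d\-\\.]+$", value) for value in second_row))  # True
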
@@ -595,8 +560,6 @@ def _create_empty_table(self):
                 self.encryption_configuration
             )
         table_obj_api_repr = table.to_api_repr()
-        if not self.schema_fields and self.source_format == "CSV":
-            table_obj_api_repr = self._check_schema_fields(table_obj_api_repr)
 
         self.log.info("Creating external table: %s", self.destination_project_dataset_table)
         self.hook.create_empty_table(
@@ -649,8 +612,6 @@ def _use_existing_table(self):
 
         if self.schema_fields:
             self.configuration["load"]["schema"] = {"fields": self.schema_fields}
-        elif self.source_format == "CSV":
-            self.configuration = self._check_schema_fields(self.configuration)
 
         if self.schema_update_options:
             if self.write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]: