Skip to content

Commit

Permalink
fix(ingest): bigquery-beta - handling complex types properly (#6062)
Browse files Browse the repository at this point in the history
  • Loading branch information
treff7es authored Sep 27, 2022
1 parent 59a2228 commit 3b9e979
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@
GlobalTagsClass,
TagAssociationClass,
)
from datahub.utilities.hive_schema_to_avro import (
HiveColumnToAvroConverter,
get_schema_fields_for_hive_column,
)
from datahub.utilities.mapping import Constants
from datahub.utilities.perf_timer import PerfTimer
from datahub.utilities.registries.domain_registry import DomainRegistry
Expand Down Expand Up @@ -785,20 +789,35 @@ def gen_dataset_urn(self, dataset_name: str, project_id: str, table: str) -> str
)
return dataset_urn

def gen_schema_metadata(
self,
dataset_urn: str,
table: Union[BigqueryTable, BigqueryView],
dataset_name: str,
) -> MetadataWorkUnit:
schema_metadata = SchemaMetadata(
schemaName=dataset_name,
platform=make_data_platform_urn(self.platform),
version=0,
hash="",
platformSchema=MySqlDDL(tableSchema=""),
fields=[
SchemaField(
def gen_schema_fields(self, columns: List[BigqueryColumn]) -> List[SchemaField]:
schema_fields: List[SchemaField] = []

HiveColumnToAvroConverter._STRUCT_TYPE_SEPARATOR = " "
_COMPLEX_TYPE = re.compile("^(struct|array)")
last_id = -1
for col in columns:

if _COMPLEX_TYPE.match(col.data_type.lower()):
# If we have seen this ordinal position before, it most probably means we already processed this complex type
if last_id != col.ordinal_position:
schema_fields.extend(
get_schema_fields_for_hive_column(
col.name, col.data_type.lower(), description=col.comment
)
)

# We have to add complex type comments to the correct level
if col.comment:
for idx, field in enumerate(schema_fields):
# Remove all the [version=2.0].[type=struct]. tags to get the field path
if (
re.sub(r"\[.*?\]\.", "", field.fieldPath, 0, re.MULTILINE)
== col.field_path
):
field.description = col.comment
schema_fields[idx] = field
else:
field = SchemaField(
fieldPath=col.name,
type=SchemaFieldDataType(
self.BIGQUERY_FIELD_TYPE_MAPPINGS.get(col.data_type, NullType)()
Expand All @@ -817,8 +836,24 @@ def gen_schema_metadata(
if col.is_partition_column
else GlobalTagsClass(tags=[]),
)
for col in table.columns
],
schema_fields.append(field)
last_id = col.ordinal_position
return schema_fields

def gen_schema_metadata(
self,
dataset_urn: str,
table: Union[BigqueryTable, BigqueryView],
dataset_name: str,
) -> MetadataWorkUnit:

schema_metadata = SchemaMetadata(
schemaName=dataset_name,
platform=make_data_platform_urn(self.platform),
version=0,
hash="",
platformSchema=MySqlDDL(tableSchema=""),
fields=self.gen_schema_fields(table.columns),
)
wu = wrap_aspect_as_workunit(
"dataset", dataset_urn, "schemaMetadata", schema_metadata
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class BigqueryTableIdentifier:
table: str

invalid_chars: ClassVar[Set[str]] = {"$", "@"}
_BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX: ClassVar[str] = "((.+)[_$])?(\\d{4,10})$"
_BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX: ClassVar[str] = "((.+)[_$])?(\\d{8})$"

@staticmethod
def get_table_and_shard(table_name: str) -> Tuple[str, Optional[str]]:
Expand All @@ -101,17 +101,10 @@ def from_string_name(cls, table: str) -> "BigqueryTableIdentifier":
def raw_table_name(self):
return f"{self.project_id}.{self.dataset}.{self.table}"

@staticmethod
def _remove_suffix(input_string: str, suffixes: List[str]) -> str:
for suffix in suffixes:
if input_string.endswith(suffix):
return input_string[: -len(suffix)]
return input_string

def get_table_display_name(self) -> str:
shortened_table_name = self.table
# if table name ends in _* or * then we strip it as that represents a query on a sharded table
shortened_table_name = self._remove_suffix(shortened_table_name, ["_*", "*"])
shortened_table_name = re.sub("(_(.+)?\\*)|\\*$", "", shortened_table_name)

table_name, _ = self.get_table_and_shard(shortened_table_name)
if not table_name:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
logger: logging.Logger = logging.getLogger(__name__)


@dataclass
@dataclass(frozen=True, eq=True)
class BigqueryColumn:
name: str
ordinal_position: int
field_path: str
is_nullable: bool
is_partition_column: bool
data_type: str
Expand Down Expand Up @@ -175,6 +176,7 @@ class BigqueryQuery:
c.table_name as table_name,
c.column_name as column_name,
c.ordinal_position as ordinal_position,
cfp.field_path as field_path,
c.is_nullable as is_nullable,
c.data_type as data_type,
description as comment,
Expand All @@ -194,6 +196,7 @@ class BigqueryQuery:
c.table_name as table_name,
c.column_name as column_name,
c.ordinal_position as ordinal_position,
cfp.field_path as field_path,
c.is_nullable as is_nullable,
c.data_type as data_type,
c.is_hidden as is_hidden,
Expand Down Expand Up @@ -355,6 +358,7 @@ def get_columns_for_dataset(
BigqueryColumn(
name=column.column_name,
ordinal_position=column.ordinal_position,
field_path=column.field_path,
is_nullable=column.is_nullable == "YES",
data_type=column.data_type,
comment=column.comment,
Expand All @@ -379,6 +383,7 @@ def get_columns_for_table(
name=column.column_name,
ordinal_position=column.ordinal_position,
is_nullable=column.is_nullable == "YES",
field_path=column.field_path,
data_type=column.data_type,
comment=column.comment,
is_partition_column=column.is_partitioning_column == "YES",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,19 +268,12 @@ def is_temporary_table(self, prefix: str) -> bool:
# Temporary tables will have a dataset that begins with an underscore.
return self.dataset.startswith(prefix)

@staticmethod
def remove_suffix(input_string, suffix):
if suffix and input_string.endswith(suffix):
return input_string[: -len(suffix)]
return input_string

def remove_extras(self, sharded_table_regex: str) -> "BigQueryTableRef":
# Handle partitioned and sharded tables.
table_name: Optional[str] = None
shortened_table_name = self.table
# if table name ends in _* or * then we strip it as that represents a query on a sharded table
shortened_table_name = self.remove_suffix(shortened_table_name, "_*")
shortened_table_name = self.remove_suffix(shortened_table_name, "*")
shortened_table_name = re.sub("(_(.+)?\\*)|\\*$", "", shortened_table_name)

matches = re.match(sharded_table_regex, shortened_table_name)
if matches:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from datahub.configuration.common import ConfigModel, ConfigurationError

_BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX: str = "((.+)[_$])?(\\d{4,10})$"
_BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX: str = "((.+)[_$])?(\\d{8})$"


class BigQueryBaseConfig(ConfigModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class HiveColumnToAvroConverter:
"float": "float",
"tinyint": "int",
"smallint": "int",
"int": "int",
"bigint": "long",
"varchar": "string",
"char": "string",
Expand All @@ -34,6 +33,8 @@ class HiveColumnToAvroConverter:

_FIXED_STRING = re.compile(r"(var)?char\(\s*(\d+)\s*\)")

_STRUCT_TYPE_SEPARATOR = ":"

@staticmethod
def _parse_datatype_string(
s: str, **kwargs: Any
Expand Down Expand Up @@ -103,7 +104,9 @@ def _parse_struct_fields_string(s: str, **kwargs: Any) -> Dict[str, object]:
parts = HiveColumnToAvroConverter._ignore_brackets_split(s, ",")
fields = []
for part in parts:
name_and_type = HiveColumnToAvroConverter._ignore_brackets_split(part, ":")
name_and_type = HiveColumnToAvroConverter._ignore_brackets_split(
part.strip(), HiveColumnToAvroConverter._STRUCT_TYPE_SEPARATOR
)
if len(name_and_type) != 2:
raise ValueError(
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_bigquery_ref_extra_removal():

table_ref = BigQueryTableRef("project-1234", "dataset-4567", "foo_2022")
new_table_ref = table_ref.remove_extras(_BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX)
assert new_table_ref.table == "foo"
assert new_table_ref.table == "foo_2022"
assert new_table_ref.project == table_ref.project
assert new_table_ref.dataset == table_ref.dataset

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def test_bigquery_table_sanitasitation():
new_table_ref = BigqueryTableIdentifier.from_string_name(
table_ref.table_identifier.get_table_name()
)
assert new_table_ref.table == "foo"
assert new_table_ref.table == "foo_2022"
assert new_table_ref.project_id == "project-1234"
assert new_table_ref.dataset == "dataset-4567"

Expand Down

0 comments on commit 3b9e979

Please sign in to comment.