feat: add OpenLineage support for BigQuery Create Table operators #44783

Merged · 1 commit · Dec 10, 2024
providers/src/airflow/providers/google/cloud/openlineage/utils.py (17 additions, 0 deletions)
@@ -33,12 +33,15 @@
    ColumnLineageDatasetFacet,
    DocumentationDatasetFacet,
    Fields,
    Identifier,
    InputField,
    RunFacet,
    SchemaDatasetFacet,
    SchemaDatasetFacetFields,
    SymlinksDatasetFacet,
)
from airflow.providers.google import __version__ as provider_version
from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url

BIGQUERY_NAMESPACE = "bigquery"
BIGQUERY_URI = "bigquery"
@@ -113,6 +116,20 @@ def get_facets_from_bq_table(table: Table) -> dict[str, BaseFacet]:
    if table.description:
        facets["documentation"] = DocumentationDatasetFacet(description=table.description)

    if table.external_data_configuration:
        symlinks = set()
        for uri in table.external_data_configuration.source_uris:
            if uri.startswith("gs://"):
                bucket, blob = _parse_gcs_url(uri)
                blob = extract_ds_name_from_gcs_path(blob)
                symlinks.add((f"gs://{bucket}", blob))

        facets["symlink"] = SymlinksDatasetFacet(
            identifiers=[
                Identifier(namespace=namespace, name=name, type="file")
                for namespace, name in sorted(symlinks)
            ]
        )
    return facets


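A quick, hedged sketch of what the new branch yields (not part of the PR; `my-project`, `my-bucket`, and the paths are illustrative, and the expected identifiers are inferred from the tests further down):

```python
from google.cloud.bigquery import Table

from airflow.providers.google.cloud.openlineage.utils import get_facets_from_bq_table

# Hypothetical external table backed by a wildcarded GCS URI.
table = Table.from_api_repr(
    {
        "tableReference": {"projectId": "my-project", "datasetId": "my_dataset", "tableId": "events"},
        "externalDataConfiguration": {
            "sourceFormat": "CSV",
            "sourceUris": ["gs://my-bucket/raw/events/part-*"],
        },
    }
)

facets = get_facets_from_bq_table(table)
# Each gs:// URI is split by _parse_gcs_url, the wildcarded file component is
# trimmed by extract_ds_name_from_gcs_path, and the pairs are deduplicated and
# sorted, so the expectation is:
#   facets["symlink"].identifiers ==
#       [Identifier(namespace="gs://my-bucket", name="raw/events", type="file")]
```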
providers/src/airflow/providers/google/cloud/operators/bigquery.py (52 additions, 13 deletions)
@@ -1365,7 +1365,7 @@ def execute(self, context: Context) -> None:

        try:
            self.log.info("Creating table")
-           table = bq_hook.create_empty_table(
+           self._table = bq_hook.create_empty_table(
                project_id=self.project_id,
                dataset_id=self.dataset_id,
                table_id=self.table_id,
@@ -1382,12 +1382,15 @@ def execute(self, context: Context) -> None:
            persist_kwargs = {
                "context": context,
                "task_instance": self,
-               "project_id": table.to_api_repr()["tableReference"]["projectId"],
-               "dataset_id": table.to_api_repr()["tableReference"]["datasetId"],
-               "table_id": table.to_api_repr()["tableReference"]["tableId"],
+               "project_id": self._table.to_api_repr()["tableReference"]["projectId"],
+               "dataset_id": self._table.to_api_repr()["tableReference"]["datasetId"],
+               "table_id": self._table.to_api_repr()["tableReference"]["tableId"],
            }
            self.log.info(
-               "Table %s.%s.%s created successfully", table.project, table.dataset_id, table.table_id
+               "Table %s.%s.%s created successfully",
+               self._table.project,
+               self._table.dataset_id,
+               self._table.table_id,
            )
        except Conflict:
            error_msg = f"Table {self.dataset_id}.{self.table_id} already exists."
@@ -1407,6 +1410,24 @@ def execute(self, context: Context) -> None:

        BigQueryTableLink.persist(**persist_kwargs)

    def get_openlineage_facets_on_complete(self, task_instance):
        from airflow.providers.common.compat.openlineage.facet import Dataset
        from airflow.providers.google.cloud.openlineage.utils import (
            BIGQUERY_NAMESPACE,
            get_facets_from_bq_table,
        )
        from airflow.providers.openlineage.extractors import OperatorLineage

        table_info = self._table.to_api_repr()["tableReference"]
        table_id = ".".join((table_info["projectId"], table_info["datasetId"], table_info["tableId"]))
        output_dataset = Dataset(
            namespace=BIGQUERY_NAMESPACE,
            name=table_id,
            facets=get_facets_from_bq_table(self._table),
        )

        return OperatorLineage(outputs=[output_dataset])


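For reviewers wondering how this method gets invoked: the OpenLineage provider's default extractor looks it up by name when the task reaches a terminal state, so no explicit registration is needed. A minimal sketch of exercising it by hand, patching `BigQueryHook` the same way the tests below do (project, dataset, and table names are illustrative):

```python
from unittest.mock import MagicMock, patch

from google.cloud.bigquery import Table

from airflow.providers.google.cloud.operators.bigquery import BigQueryCreateEmptyTableOperator

operator = BigQueryCreateEmptyTableOperator(
    task_id="create_table",
    project_id="my-project",
    dataset_id="my_dataset",
    table_id="my_table",
)

with patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") as mock_hook:
    mock_hook.return_value.create_empty_table.return_value = Table.from_api_repr(
        {
            "tableReference": {
                "projectId": "my-project",
                "datasetId": "my_dataset",
                "tableId": "my_table",
            }
        }
    )
    # execute() now stores the created Table on self._table for later lineage use.
    operator.execute(context=MagicMock())

# The listener passes a TaskInstance; the argument is unused by this implementation.
lineage = operator.get_openlineage_facets_on_complete(None)
assert lineage.outputs[0].namespace == "bigquery"
assert lineage.outputs[0].name == "my-project.my_dataset.my_table"
```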
class BigQueryCreateExternalTableOperator(GoogleCloudBaseOperator):
"""
@@ -1632,15 +1653,15 @@ def execute(self, context: Context) -> None:
            impersonation_chain=self.impersonation_chain,
        )
        if self.table_resource:
-           table = bq_hook.create_empty_table(
+           self._table = bq_hook.create_empty_table(
                table_resource=self.table_resource,
            )
            BigQueryTableLink.persist(
                context=context,
                task_instance=self,
-               dataset_id=table.to_api_repr()["tableReference"]["datasetId"],
-               project_id=table.to_api_repr()["tableReference"]["projectId"],
-               table_id=table.to_api_repr()["tableReference"]["tableId"],
+               dataset_id=self._table.to_api_repr()["tableReference"]["datasetId"],
+               project_id=self._table.to_api_repr()["tableReference"]["projectId"],
+               table_id=self._table.to_api_repr()["tableReference"]["tableId"],
            )
            return

@@ -1691,18 +1712,36 @@ def execute(self, context: Context) -> None:
"encryptionConfiguration": self.encryption_configuration,
}

table = bq_hook.create_empty_table(
self._table = bq_hook.create_empty_table(
table_resource=table_resource,
)

BigQueryTableLink.persist(
context=context,
task_instance=self,
dataset_id=table.to_api_repr()["tableReference"]["datasetId"],
project_id=table.to_api_repr()["tableReference"]["projectId"],
table_id=table.to_api_repr()["tableReference"]["tableId"],
dataset_id=self._table.to_api_repr()["tableReference"]["datasetId"],
project_id=self._table.to_api_repr()["tableReference"]["projectId"],
table_id=self._table.to_api_repr()["tableReference"]["tableId"],
)

    def get_openlineage_facets_on_complete(self, task_instance):
        from airflow.providers.common.compat.openlineage.facet import Dataset
        from airflow.providers.google.cloud.openlineage.utils import (
            BIGQUERY_NAMESPACE,
            get_facets_from_bq_table,
        )
        from airflow.providers.openlineage.extractors import OperatorLineage

        table_info = self._table.to_api_repr()["tableReference"]
        table_id = ".".join((table_info["projectId"], table_info["datasetId"], table_info["tableId"]))
        output_dataset = Dataset(
            namespace=BIGQUERY_NAMESPACE,
            name=table_id,
            facets=get_facets_from_bq_table(self._table),
        )

        return OperatorLineage(outputs=[output_dataset])


class BigQueryDeleteDatasetOperator(GoogleCloudBaseOperator):
"""
providers/tests/google/cloud/openlineage/test_utils.py (12 additions, 0 deletions)
@@ -27,9 +27,11 @@
    Dataset,
    DocumentationDatasetFacet,
    Fields,
    Identifier,
    InputField,
    SchemaDatasetFacet,
    SchemaDatasetFacetFields,
    SymlinksDatasetFacet,
)
from airflow.providers.google.cloud.openlineage.utils import (
    extract_ds_name_from_gcs_path,
@@ -49,6 +51,10 @@
{"name": "field2", "type": "INTEGER"},
]
},
"externalDataConfiguration": {
"sourceFormat": "CSV",
"sourceUris": ["gs://bucket/path/to/files*", "gs://second_bucket/path/to/other/files*"],
},
}
TEST_TABLE: Table = Table.from_api_repr(TEST_TABLE_API_REPR)
TEST_EMPTY_TABLE_API_REPR = {
@@ -84,6 +90,12 @@ def test_get_facets_from_bq_table():
            ]
        ),
        "documentation": DocumentationDatasetFacet(description="Table description."),
        "symlink": SymlinksDatasetFacet(
            identifiers=[
                Identifier(namespace="gs://bucket", name="path/to", type="file"),
                Identifier(namespace="gs://second_bucket", name="path/to/other", type="file"),
            ]
        ),
    }
    result = get_facets_from_bq_table(TEST_TABLE)
    assert result == expected_facets
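One subtlety the expected `symlink` facet above encodes: the identifier name is not the raw blob but the output of `extract_ds_name_from_gcs_path`, which (judging by these assertions) trims the wildcarded file component from the path. The mapping the test relies on, spelled out:

```python
from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path

# "gs://bucket/path/to/files*" is split into bucket "bucket" and blob
# "path/to/files*" by _parse_gcs_url; the wildcarded leaf is then dropped:
assert extract_ds_name_from_gcs_path("path/to/files*") == "path/to"
assert extract_ds_name_from_gcs_path("path/to/other/files*") == "path/to/other"
```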
providers/tests/google/cloud/operators/test_bigquery.py (123 additions, 1 deletion)
@@ -26,7 +26,7 @@

import pandas as pd
import pytest
-from google.cloud.bigquery import DEFAULT_RETRY, ScalarQueryParameter
+from google.cloud.bigquery import DEFAULT_RETRY, ScalarQueryParameter, Table
from google.cloud.exceptions import Conflict

from airflow.exceptions import (
@@ -36,11 +36,17 @@
    TaskDeferred,
)
from airflow.providers.common.compat.openlineage.facet import (
    DocumentationDatasetFacet,
    ErrorMessageRunFacet,
    ExternalQueryRunFacet,
    Identifier,
    InputDataset,
    SchemaDatasetFacet,
    SchemaDatasetFacetFields,
    SQLJobFacet,
    SymlinksDatasetFacet,
)
from airflow.providers.google.cloud.openlineage.utils import BIGQUERY_NAMESPACE
from airflow.providers.google.cloud.operators.bigquery import (
    BigQueryCheckOperator,
    BigQueryColumnCheckOperator,
@@ -259,6 +265,63 @@ def test_create_existing_table(self, mock_hook, caplog, if_exists, is_conflict,
        if log_msg is not None:
            assert log_msg in caplog.text

    @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
    def test_get_openlineage_facets_on_complete(self, mock_hook):
        schema_fields = [
            {"name": "field1", "type": "STRING", "description": "field1 description"},
            {"name": "field2", "type": "INTEGER"},
        ]
        table_resource = {
            "tableReference": {
                "projectId": TEST_GCP_PROJECT_ID,
                "datasetId": TEST_DATASET,
                "tableId": TEST_TABLE_ID,
            },
            "description": "Table description.",
            "schema": {"fields": schema_fields},
        }
        mock_hook.return_value.create_empty_table.return_value = Table.from_api_repr(table_resource)
        operator = BigQueryCreateEmptyTableOperator(
            task_id=TASK_ID,
            dataset_id=TEST_DATASET,
            project_id=TEST_GCP_PROJECT_ID,
            table_id=TEST_TABLE_ID,
            schema_fields=schema_fields,
        )
        operator.execute(context=MagicMock())

        mock_hook.return_value.create_empty_table.assert_called_once_with(
            dataset_id=TEST_DATASET,
            project_id=TEST_GCP_PROJECT_ID,
            table_id=TEST_TABLE_ID,
            schema_fields=schema_fields,
            time_partitioning={},
            cluster_fields=None,
            labels=None,
            view=None,
            materialized_view=None,
            encryption_configuration=None,
            table_resource=None,
            exists_ok=False,
        )

        result = operator.get_openlineage_facets_on_complete(None)
        assert not result.run_facets
        assert not result.job_facets
        assert not result.inputs
        assert len(result.outputs) == 1
        assert result.outputs[0].namespace == BIGQUERY_NAMESPACE
        assert result.outputs[0].name == f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}"
        assert result.outputs[0].facets == {
            "schema": SchemaDatasetFacet(
                fields=[
                    SchemaDatasetFacetFields(name="field1", type="STRING", description="field1 description"),
                    SchemaDatasetFacetFields(name="field2", type="INTEGER"),
                ]
            ),
            "documentation": DocumentationDatasetFacet(description="Table description."),
        }


class TestBigQueryCreateExternalTableOperator:
    @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
@@ -344,6 +407,65 @@ def test_execute_with_parquet_format(self, mock_hook):
        operator.execute(context=MagicMock())
        mock_hook.return_value.create_empty_table.assert_called_once_with(table_resource=table_resource)

    @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
    def test_get_openlineage_facets_on_complete(self, mock_hook):
        table_resource = {
            "tableReference": {
                "projectId": TEST_GCP_PROJECT_ID,
                "datasetId": TEST_DATASET,
                "tableId": TEST_TABLE_ID,
            },
            "description": "Table description.",
            "schema": {
                "fields": [
                    {"name": "field1", "type": "STRING", "description": "field1 description"},
                    {"name": "field2", "type": "INTEGER"},
                ]
            },
            "externalDataConfiguration": {
                "sourceUris": [
                    f"gs://{TEST_GCS_BUCKET}/{source_object}" for source_object in TEST_GCS_CSV_DATA
                ],
                "sourceFormat": TEST_SOURCE_CSV_FORMAT,
            },
        }
        mock_hook.return_value.create_empty_table.return_value = Table.from_api_repr(table_resource)
        operator = BigQueryCreateExternalTableOperator(
            task_id=TASK_ID,
            bucket=TEST_GCS_BUCKET,
            source_objects=TEST_GCS_CSV_DATA,
            table_resource=table_resource,
        )

        mock_hook.return_value.split_tablename.return_value = (
            TEST_GCP_PROJECT_ID,
            TEST_DATASET,
            TEST_TABLE_ID,
        )

        operator.execute(context=MagicMock())
        mock_hook.return_value.create_empty_table.assert_called_once_with(table_resource=table_resource)

        result = operator.get_openlineage_facets_on_complete(None)
        assert not result.run_facets
        assert not result.job_facets
        assert not result.inputs
        assert len(result.outputs) == 1
        assert result.outputs[0].namespace == BIGQUERY_NAMESPACE
        assert result.outputs[0].name == f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}"
        assert result.outputs[0].facets == {
            "schema": SchemaDatasetFacet(
                fields=[
                    SchemaDatasetFacetFields(name="field1", type="STRING", description="field1 description"),
                    SchemaDatasetFacetFields(name="field2", type="INTEGER"),
                ]
            ),
            "documentation": DocumentationDatasetFacet(description="Table description."),
            "symlink": SymlinksDatasetFacet(
                identifiers=[Identifier(namespace=f"gs://{TEST_GCS_BUCKET}", name="dir1", type="file")]
            ),
        }


class TestBigQueryDeleteDatasetOperator:
    @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")