OpenLineage S3 to Redshift operator integration #41575

Merged
101 changes: 91 additions & 10 deletions airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -121,6 +121,10 @@ def __init__(
if arg in self.redshift_data_api_kwargs:
raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs")

@property
def use_redshift_data(self):
return bool(self.redshift_data_api_kwargs)

def _build_copy_query(
self, copy_destination: str, credentials_block: str, region_info: str, copy_options: str
) -> str:
@@ -138,11 +142,11 @@ def execute(self, context: Context) -> None:
if self.method not in AVAILABLE_METHODS:
raise AirflowException(f"Method not found! Available methods: {AVAILABLE_METHODS}")

- redshift_hook: RedshiftDataHook | RedshiftSQLHook
- if self.redshift_data_api_kwargs:
-     redshift_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
+ if self.use_redshift_data:
+     redshift_data_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
  else:
-     redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
+     redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)

conn = S3Hook.get_connection(conn_id=self.aws_conn_id) if self.aws_conn_id else None
region_info = ""
if conn and conn.extra_dejson.get("region", False):
@@ -167,12 +171,12 @@ def execute(self, context: Context) -> None:
if self.method == "REPLACE":
sql = ["BEGIN;", f"DELETE FROM {destination};", copy_statement, "COMMIT"]
elif self.method == "UPSERT":
- if isinstance(redshift_hook, RedshiftDataHook):
-     keys = self.upsert_keys or redshift_hook.get_table_primary_key(
+ if self.use_redshift_data:
+     keys = self.upsert_keys or redshift_data_hook.get_table_primary_key(
          table=self.table, schema=self.schema, **self.redshift_data_api_kwargs
      )
  else:
-     keys = self.upsert_keys or redshift_hook.get_table_primary_key(self.table, self.schema)
+     keys = self.upsert_keys or redshift_sql_hook.get_table_primary_key(self.table, self.schema)
if not keys:
raise AirflowException(
f"No primary key on {self.schema}.{self.table}. Please provide keys on 'upsert_keys'"
@@ -192,8 +196,85 @@ def execute(self, context: Context) -> None:
sql = copy_statement

self.log.info("Executing COPY command...")
- if isinstance(redshift_hook, RedshiftDataHook):
-     redshift_hook.execute_query(sql=sql, **self.redshift_data_api_kwargs)
+ if self.use_redshift_data:
+     redshift_data_hook.execute_query(sql=sql, **self.redshift_data_api_kwargs)
  else:
-     redshift_hook.run(sql, autocommit=self.autocommit)
+     redshift_sql_hook.run(sql, autocommit=self.autocommit)
self.log.info("COPY command complete...")

def get_openlineage_facets_on_complete(self, task_instance):
"""Implement on_complete as we will query destination table."""
from pathlib import Path

from airflow.providers.amazon.aws.utils.openlineage import (
get_facets_from_redshift_table,
get_identity_column_lineage_facet,
)
from airflow.providers.common.compat.openlineage.facet import (
Dataset,
Identifier,
LifecycleStateChange,
LifecycleStateChangeDatasetFacet,
SymlinksDatasetFacet,
)
from airflow.providers.openlineage.extractors import OperatorLineage

if self.use_redshift_data:
redshift_data_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
database = self.redshift_data_api_kwargs.get("database")
identifier = self.redshift_data_api_kwargs.get(
"cluster_identifier", self.redshift_data_api_kwargs.get("workgroup_name")
)
port = self.redshift_data_api_kwargs.get("port", "5439")
authority = f"{identifier}.{redshift_data_hook.region_name}:{port}"
output_dataset_facets = get_facets_from_redshift_table(
redshift_data_hook, self.table, self.redshift_data_api_kwargs, self.schema
)
else:
redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
database = redshift_sql_hook.conn.schema
authority = redshift_sql_hook.get_openlineage_database_info(redshift_sql_hook.conn).authority
output_dataset_facets = get_facets_from_redshift_table(
redshift_sql_hook, self.table, self.redshift_data_api_kwargs, self.schema
)

input_dataset_facets = {}
if not self.column_list:
# If column_list is not specified, then we know that the input file matches the columns of the output table.
input_dataset_facets["schema"] = output_dataset_facets["schema"]

dataset_name = self.s3_key
if "*" in dataset_name:
# If a wildcard ("*") is used in the s3 path, we want the dataset name to be the directory name,
# but we create a symlink to the full object path with the wildcard.
input_dataset_facets["symlink"] = SymlinksDatasetFacet(
identifiers=[Identifier(namespace=f"s3://{self.s3_bucket}", name=dataset_name, type="file")]
)
dataset_name = Path(dataset_name).parent.as_posix()
if dataset_name == ".":
# blob path does not have a leading slash, but we need the root dataset name to be "/"
dataset_name = "/"

input_dataset = Dataset(
namespace=f"s3://{self.s3_bucket}",
name=dataset_name,
facets=input_dataset_facets,
)

output_dataset_facets["columnLineage"] = get_identity_column_lineage_facet(
field_names=[field.name for field in output_dataset_facets["schema"].fields],
input_datasets=[input_dataset],
)

if self.method == "REPLACE":
output_dataset_facets["lifecycleStateChange"] = LifecycleStateChangeDatasetFacet(
lifecycleStateChange=LifecycleStateChange.OVERWRITE
)

output_dataset = Dataset(
namespace=f"redshift://{authority}",
name=f"{database}.{self.schema}.{self.table}",
facets=output_dataset_facets,
)

return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])
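For context, here is a minimal usage sketch (not part of this diff) of how the updated operator might be exercised; the DAG id, bucket, table, and connection ids are illustrative assumptions. Passing any redshift_data_api_kwargs flips use_redshift_data to True, so the RedshiftDataHook branch runs both for the COPY and for the lineage method above, while a wildcard s3_key names the input dataset after its parent directory and keeps the full path in a symlink facet.

import pendulum

from airflow import DAG
from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator

with DAG(
    dag_id="example_s3_to_redshift_lineage",  # assumed DAG id
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule=None,
) as dag:
    copy_orders = S3ToRedshiftOperator(
        task_id="copy_orders",
        s3_bucket="my-bucket",  # assumed bucket
        s3_key="raw/orders/*",  # wildcard: input dataset is named "raw/orders", symlink keeps "raw/orders/*"
        schema="public",
        table="orders",  # assumed table
        redshift_conn_id="redshift_default",
        # Any Data API kwargs make use_redshift_data True, selecting RedshiftDataHook.
        redshift_data_api_kwargs={"database": "dev", "cluster_identifier": "my-cluster"},
        method="REPLACE",  # REPLACE also adds a lifecycleStateChange (OVERWRITE) facet to the output dataset
    )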
136 changes: 136 additions & 0 deletions airflow/providers/amazon/aws/utils/openlineage.py
@@ -0,0 +1,136 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
from airflow.providers.common.compat.openlineage.facet import (
ColumnLineageDatasetFacet,
DocumentationDatasetFacet,
Fields,
InputField,
SchemaDatasetFacet,
SchemaDatasetFacetFields,
)

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook


def get_facets_from_redshift_table(
redshift_hook: RedshiftDataHook | RedshiftSQLHook,
table: str,
redshift_data_api_kwargs: dict,
schema: str = "public",
) -> dict[Any, Any]:
"""
Query Redshift for table metadata.

SchemaDatasetFacet and DocumentationDatasetFacet (if the table has a description) will be created.
"""
sql = f"""
SELECT
cols.column_name,
cols.data_type,
col_des.description as column_description,
tbl_des.description as table_description
FROM
information_schema.columns cols
LEFT JOIN
pg_catalog.pg_description col_des
ON
cols.ordinal_position = col_des.objsubid
AND col_des.objoid = (SELECT oid FROM pg_class WHERE relnamespace =
(SELECT oid FROM pg_namespace WHERE nspname = cols.table_schema) AND relname = cols.table_name)
LEFT JOIN
pg_catalog.pg_class tbl
ON
tbl.relname = cols.table_name
AND tbl.relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = cols.table_schema)
LEFT JOIN
pg_catalog.pg_description tbl_des
ON
tbl.oid = tbl_des.objoid
AND tbl_des.objsubid = 0
WHERE
cols.table_name = '{table}'
AND cols.table_schema = '{schema}';
"""
if isinstance(redshift_hook, RedshiftSQLHook):
records = redshift_hook.get_records(sql)
if records:
table_description = records[0][-1] # Assuming the table description is the same for all rows
else:
table_description = None
documentation = DocumentationDatasetFacet(description=table_description or "")
table_schema = SchemaDatasetFacet(
fields=[
SchemaDatasetFacetFields(name=field[0], type=field[1], description=field[2])
for field in records
]
)
else:
statement_id = redshift_hook.execute_query(sql=sql, poll_interval=1, **redshift_data_api_kwargs)
response = redshift_hook.conn.get_statement_result(Id=statement_id)

table_schema = SchemaDatasetFacet(
fields=[
SchemaDatasetFacetFields(
name=field[0]["stringValue"],
type=field[1]["stringValue"],
description=field[2].get("stringValue"),
)
for field in response["Records"]
]
)
# Table description will be the same for all fields, so we retrieve it from the first field.
documentation = DocumentationDatasetFacet(
description=response["Records"][0][3].get("stringValue") or ""
)

return {"schema": table_schema, "documentation": documentation}


def get_identity_column_lineage_facet(
field_names,
input_datasets,
) -> ColumnLineageDatasetFacet:
"""
Get column lineage facet.

Simple lineage will be created, where each source column corresponds to a single destination column
in each input dataset and no transformations are made.
"""
if field_names and not input_datasets:
raise ValueError("When providing `field_names` You must provide at least one `input_dataset`.")

column_lineage_facet = ColumnLineageDatasetFacet(
fields={
field: Fields(
inputFields=[
InputField(namespace=dataset.namespace, name=dataset.name, field=field)
for dataset in input_datasets
],
transformationType="IDENTITY",
transformationDescription="identical",
)
for field in field_names
}
)
return column_lineage_facet
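
A short usage sketch of the two helpers above (illustrative only, not part of the module): the connection id, table, dataset, and field names are assumptions, and the first call needs a reachable Redshift connection.

from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
from airflow.providers.amazon.aws.utils.openlineage import (
    get_facets_from_redshift_table,
    get_identity_column_lineage_facet,
)
from airflow.providers.common.compat.openlineage.facet import Dataset

# SQL-hook path: redshift_data_api_kwargs can stay empty, it is only used by the Data API branch.
hook = RedshiftSQLHook(redshift_conn_id="redshift_default")  # assumed connection id
facets = get_facets_from_redshift_table(hook, "orders", redshift_data_api_kwargs={}, schema="public")
# facets["schema"] is a SchemaDatasetFacet, facets["documentation"] holds the table comment (or "").

# Identity lineage: every destination field maps 1:1 to the same-named field in each input dataset.
source = Dataset(namespace="s3://my-bucket", name="raw/orders")  # assumed S3 input dataset
lineage = get_identity_column_lineage_facet(
    field_names=[field.name for field in facets["schema"].fields],
    input_datasets=[source],
)
# e.g. lineage.fields["order_id"].inputFields[0].field == "order_id"
#      lineage.fields["order_id"].transformationType == "IDENTITY"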