From 34b0c6c4e36a8a191bbadea6bd6e6652a2c2f2e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Artur=20Skar=C5=BCy=C5=84ski?= Date: Fri, 9 Aug 2024 16:12:41 +0200 Subject: [PATCH 1/6] add OpenLineage support to S3ToRedshiftOperator --- .../amazon/aws/transfers/s3_to_redshift.py | 86 ++++++++++- .../amazon/aws/utils/open_lineage.py | 136 ++++++++++++++++++ 2 files changed, 217 insertions(+), 5 deletions(-) create mode 100644 airflow/providers/amazon/aws/utils/open_lineage.py diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py index 161276b33cb0..6a22cd8b2fb8 100644 --- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py +++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py @@ -134,15 +134,17 @@ def _build_copy_query( {copy_options}; """ + def _create_hook(self) -> RedshiftDataHook | RedshiftSQLHook: + """If redshift_data_api_kwargs are provided, create RedshiftDataHook. RedshiftSQLHook otherwise.""" + if self.redshift_data_api_kwargs: + return RedshiftDataHook(aws_conn_id=self.redshift_conn_id) + return RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id) + def execute(self, context: Context) -> None: if self.method not in AVAILABLE_METHODS: raise AirflowException(f"Method not found! Available methods: {AVAILABLE_METHODS}") - redshift_hook: RedshiftDataHook | RedshiftSQLHook - if self.redshift_data_api_kwargs: - redshift_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id) - else: - redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id) + redshift_hook = self._create_hook() conn = S3Hook.get_connection(conn_id=self.aws_conn_id) if self.aws_conn_id else None region_info = "" if conn and conn.extra_dejson.get("region", False): @@ -197,3 +199,77 @@ def execute(self, context: Context) -> None: else: redshift_hook.run(sql, autocommit=self.autocommit) self.log.info("COPY command complete...") + + def get_openlineage_facets_on_complete(self, task_instance): + """Implement on_complete as we will query destination table.""" + from pathlib import Path + + from airflow.providers.amazon.aws.utils.open_lineage import ( + get_facets_from_redshift_table, + get_identity_column_lineage_facet, + ) + from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + Identifier, + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + SymlinksDatasetFacet, + ) + from airflow.providers.openlineage.extractors import OperatorLineage + + redshift_hook = self._create_hook() + if isinstance(redshift_hook, RedshiftDataHook): + database = self.redshift_data_api_kwargs.get("database") + identifier = self.redshift_data_api_kwargs.get( + "cluster_identifier" + ) or self.redshift_data_api_kwargs.get("workgroup_name") + port = self.redshift_data_api_kwargs.get("port", "5439") + authority = f"{identifier}.{redshift_hook.region_name}:{port}" + else: + database = redshift_hook.conn.schema + authority = redshift_hook.get_openlineage_database_info(redshift_hook.conn).authority + + output_dataset_facets = get_facets_from_redshift_table( + redshift_hook, self.table, self.redshift_data_api_kwargs, self.schema + ) + + input_dataset_facets = {} + if not self.column_list: + # If column_list is not specified, then we know that input file matches columns of output table. 
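+            # The schema facet computed for the output table is therefore also valid for the S3 input file.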
+ input_dataset_facets["schema"] = output_dataset_facets["schema"] + + dataset_name = self.s3_key + if "*" in dataset_name: + # If wildcard ("*") is used in s3 path, we want the name of dataset to be directory name, + # but we create a symlink to the full object path with wildcard. + input_dataset_facets["symlink"] = SymlinksDatasetFacet( + identifiers=[Identifier(namespace=f"s3://{self.s3_bucket}", name=dataset_name, type="file")] + ) + dataset_name = Path(dataset_name).parent.as_posix() + if dataset_name == ".": + # blob path does not have leading slash, but we need root dataset name to be "/" + dataset_name = "/" + + input_dataset = Dataset( + namespace=f"s3://{self.s3_bucket}", + name=dataset_name, + facets=input_dataset_facets, + ) + + output_dataset_facets["columnLineage"] = get_identity_column_lineage_facet( + field_names=[field.name for field in output_dataset_facets["schema"].fields], + input_datasets=[input_dataset], + ) + + if self.method == "REPLACE": + output_dataset_facets["lifecycleStateChange"] = LifecycleStateChangeDatasetFacet( + lifecycleStateChange=LifecycleStateChange.OVERWRITE + ) + + output_dataset = Dataset( + namespace=f"redshift://{authority}", + name=f"{database}.{self.schema}.{self.table}", + facets=output_dataset_facets, + ) + + return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset]) diff --git a/airflow/providers/amazon/aws/utils/open_lineage.py b/airflow/providers/amazon/aws/utils/open_lineage.py new file mode 100644 index 000000000000..db472a3e46c5 --- /dev/null +++ b/airflow/providers/amazon/aws/utils/open_lineage.py @@ -0,0 +1,136 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook +from airflow.providers.common.compat.openlineage.facet import ( + ColumnLineageDatasetFacet, + DocumentationDatasetFacet, + Fields, + InputField, + SchemaDatasetFacet, + SchemaDatasetFacetFields, +) + +if TYPE_CHECKING: + from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook + + +def get_facets_from_redshift_table( + redshift_hook: RedshiftDataHook | RedshiftSQLHook, + table: str, + redshift_data_api_kwargs: dict, + schema: str = "public", +) -> dict[Any, Any]: + """ + Query redshift for table metadata. + + SchemaDatasetFacet and DocumentationDatasetFacet (if table has description) will be created. 
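+
+ Metadata is read from information_schema.columns joined with pg_catalog.pg_description,
+ either with RedshiftSQLHook.get_records or through the Redshift Data API, depending on the hook passed in.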
+ """ + sql = f""" + SELECT + cols.column_name, + cols.data_type, + col_des.description as column_description, + tbl_des.description as table_description + FROM + information_schema.columns cols + LEFT JOIN + pg_catalog.pg_description col_des + ON + cols.ordinal_position = col_des.objsubid + AND col_des.objoid = (SELECT oid FROM pg_class WHERE relnamespace = + (SELECT oid FROM pg_namespace WHERE nspname = cols.table_schema) AND relname = cols.table_name) + LEFT JOIN + pg_catalog.pg_class tbl + ON + tbl.relname = cols.table_name + AND tbl.relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = cols.table_schema) + LEFT JOIN + pg_catalog.pg_description tbl_des + ON + tbl.oid = tbl_des.objoid + AND tbl_des.objsubid = 0 + WHERE + cols.table_name = '{table}' + AND cols.table_schema = '{schema}'; + """ + if isinstance(redshift_hook, RedshiftSQLHook): + records = redshift_hook.get_records(sql) + if records: + table_description = records[0][-1] # Assuming the table description is the same for all rows + else: + table_description = None + documentation = DocumentationDatasetFacet(description=table_description or "") + table_schema = SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name=field[0], type=field[1], description=field[2]) + for field in records + ] + ) + else: + statement_id = redshift_hook.execute_query(sql=sql, poll_interval=1, **redshift_data_api_kwargs) + response = redshift_hook.conn.get_statement_result(Id=statement_id) + + table_schema = SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields( + name=field[0]["stringValue"], + type=field[1]["stringValue"], + description=field[2].get("stringValue"), + ) + for field in response["Records"] + ] + ) + # Table description will be the same for all fields, so we retrieve it from first field. + documentation = DocumentationDatasetFacet( + description=response["Records"][0][3].get("stringValue") or "" + ) + + return {"schema": table_schema, "documentation": documentation} + + +def get_identity_column_lineage_facet( + field_names, + input_datasets, +) -> ColumnLineageDatasetFacet: + """ + Get column lineage facet. + + Simple lineage will be created, where each source column corresponds to single destination column + in each input dataset and there are no transformations made. 
+ """ + if field_names and not input_datasets: + raise ValueError("When providing `field_names` You must provide at least one `input_dataset`.") + + column_lineage_facet = ColumnLineageDatasetFacet( + fields={ + field: Fields( + inputFields=[ + InputField(namespace=dataset.namespace, name=dataset.name, field=field) + for dataset in input_datasets + ], + transformationType="IDENTITY", + transformationDescription="identical", + ) + for field in field_names + } + ) + return column_lineage_facet From ca72b13fb503660e2101fbc86707621dfadfc9a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Artur=20Skar=C5=BCy=C5=84ski?= Date: Tue, 13 Aug 2024 14:39:12 +0200 Subject: [PATCH 2/6] add tests --- .../aws/transfers/test_s3_to_redshift.py | 308 ++++++++++++++++++ 1 file changed, 308 insertions(+) diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py index cb5ef7fdb75b..dacfdd4ac01c 100644 --- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py +++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py @@ -26,6 +26,7 @@ from airflow.exceptions import AirflowException from airflow.models.connection import Connection from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator +from airflow.providers.common.compat.openlineage.facet import LifecycleStateChange from tests.test_utils.asserts import assert_equal_ignore_multiple_spaces @@ -496,3 +497,310 @@ def test_using_redshift_data_api(self, mock_rs, mock_run, mock_session, mock_con assert access_key in actual_copy_query assert secret_key in actual_copy_query assert_equal_ignore_multiple_spaces(actual_copy_query, expected_copy_query) + + @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection") + @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets") + @mock.patch("boto3.session.Session") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run") + def test_get_openlineage_facets_on_complete_default( + self, mock_run, mock_session, mock_connection, mock_hook + ): + access_key = "aws_access_key_id" + secret_key = "aws_secret_access_key" + mock_session.return_value = Session(access_key, secret_key) + mock_session.return_value.access_key = access_key + mock_session.return_value.secret_key = secret_key + mock_session.return_value.token = None + + mock_connection.return_value = mock.MagicMock( + schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={} + ) + + schema = "schema" + table = "table" + s3_bucket = "bucket" + s3_key = "key" + copy_options = "" + + op = S3ToRedshiftOperator( + schema=schema, + table=table, + s3_bucket=s3_bucket, + s3_key=s3_key, + copy_options=copy_options, + redshift_conn_id="redshift_conn_id", + aws_conn_id="aws_conn_id", + task_id="task_id", + dag=None, + ) + op.execute(None) + + lineage = op.get_openlineage_facets_on_complete(None) + # Hook called two times - on operator execution, and on querying data in redshift to fetch schema + assert mock_run.call_count == 2 + + assert len(lineage.inputs) == 1 + assert len(lineage.outputs) == 1 + assert lineage.inputs[0].name == s3_key + assert lineage.outputs[0].name == f"database.{schema}.{table}" + assert lineage.outputs[0].namespace == "redshift://cluster.region:5439" + + assert lineage.outputs[0].facets.get("schema") is not None + assert lineage.outputs[0].facets.get("columnLineage") is not None + + assert lineage.inputs[0].facets.get("schema") is not None + # As method 
was not overwrite, there should be no lifecycleStateChange facet + assert "lifecycleStateChange" not in lineage.outputs[0].facets + + @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection") + @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets") + @mock.patch("boto3.session.Session") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run") + def test_get_openlineage_facets_on_complete_replace( + self, mock_run, mock_session, mock_connection, mock_hook + ): + access_key = "aws_access_key_id" + secret_key = "aws_secret_access_key" + mock_session.return_value = Session(access_key, secret_key) + mock_session.return_value.access_key = access_key + mock_session.return_value.secret_key = secret_key + mock_session.return_value.token = None + + mock_connection.return_value = mock.MagicMock( + schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={} + ) + + schema = "schema" + table = "table" + s3_bucket = "bucket" + s3_key = "key" + copy_options = "" + + op = S3ToRedshiftOperator( + schema=schema, + table=table, + s3_bucket=s3_bucket, + s3_key=s3_key, + copy_options=copy_options, + method="REPLACE", + redshift_conn_id="redshift_conn_id", + aws_conn_id="aws_conn_id", + task_id="task_id", + dag=None, + ) + op.execute(None) + + lineage = op.get_openlineage_facets_on_complete(None) + + assert ( + lineage.outputs[0].facets["lifecycleStateChange"].lifecycleStateChange + == LifecycleStateChange.OVERWRITE + ) + + @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection") + @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets") + @mock.patch("boto3.session.Session") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run") + def test_get_openlineage_facets_on_complete_column_list( + self, mock_run, mock_session, mock_connection, mock_hook + ): + access_key = "aws_access_key_id" + secret_key = "aws_secret_access_key" + mock_session.return_value = Session(access_key, secret_key) + mock_session.return_value.access_key = access_key + mock_session.return_value.secret_key = secret_key + mock_session.return_value.token = None + + mock_connection.return_value = mock.MagicMock( + schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={} + ) + + schema = "schema" + table = "table" + s3_bucket = "bucket" + s3_key = "key" + copy_options = "" + + op = S3ToRedshiftOperator( + schema=schema, + table=table, + s3_bucket=s3_bucket, + s3_key=s3_key, + copy_options=copy_options, + column_list=["column1", "column2"], + redshift_conn_id="redshift_conn_id", + aws_conn_id="aws_conn_id", + task_id="task_id", + dag=None, + ) + op.execute(None) + + lineage = op.get_openlineage_facets_on_complete(None) + + assert lineage.outputs[0].facets.get("schema") is not None + assert lineage.inputs[0].facets.get("schema") is None + + @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection") + @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets") + @mock.patch("boto3.session.Session") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") + @mock.patch( + "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.region_name", + new_callable=mock.PropertyMock, + ) + def test_get_openlineage_facets_on_complete_using_redshift_data_api( + self, mock_rs_region, mock_rs, mock_session, mock_connection, mock_hook + ): + """ + Using the Redshift Data API 
instead of the SQL-based connection + """ + access_key = "aws_access_key_id" + secret_key = "aws_secret_access_key" + mock_session.return_value = Session(access_key, secret_key) + mock_session.return_value.access_key = access_key + mock_session.return_value.secret_key = secret_key + mock_session.return_value.token = None + + mock_hook.return_value = Connection() + mock_rs.execute_statement.return_value = {"Id": "STATEMENT_ID"} + mock_rs.describe_statement.return_value = {"Status": "FINISHED"} + + mock_rs_region.return_value = "region" + + schema = "schema" + table = "table" + s3_bucket = "bucket" + s3_key = "key" + copy_options = "" + + # RS Data API params + database = "database" + cluster_identifier = "cluster" + db_user = "db_user" + secret_arn = "secret_arn" + statement_name = "statement_name" + + op = S3ToRedshiftOperator( + schema=schema, + table=table, + s3_bucket=s3_bucket, + s3_key=s3_key, + copy_options=copy_options, + redshift_conn_id="redshift_conn_id", + aws_conn_id="aws_conn_id", + task_id="task_id", + dag=None, + redshift_data_api_kwargs=dict( + database=database, + cluster_identifier=cluster_identifier, + db_user=db_user, + secret_arn=secret_arn, + statement_name=statement_name, + ), + ) + op.execute(None) + + lineage = op.get_openlineage_facets_on_complete(None) + + assert len(lineage.inputs) == 1 + assert len(lineage.outputs) == 1 + assert lineage.inputs[0].name == s3_key + assert lineage.outputs[0].name == f"database.{schema}.{table}" + assert lineage.outputs[0].namespace == "redshift://cluster.region:5439" + + assert lineage.outputs[0].facets.get("schema") is not None + assert lineage.outputs[0].facets.get("columnLineage") is not None + + assert lineage.inputs[0].facets.get("schema") is not None + # As method was not overwrite, there should be no lifecycleStateChange facet + assert "lifecycleStateChange" not in lineage.outputs[0].facets + + @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection") + @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets") + @mock.patch("boto3.session.Session") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") + @mock.patch( + "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.region_name", + new_callable=mock.PropertyMock, + ) + def test_get_openlineage_facets_on_complete_data_and_sql_hooks_aligned( + self, mock_rs_region, mock_rs, mock_run, mock_session, mock_connection, mock_hook + ): + """ + Ensuring both supported hooks - RedshiftDataHook and RedshiftSQLHook return same lineage. 
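+
+ Both operators target the same cluster, database, schema and table, so names, namespaces and facets should match.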
+ """ + access_key = "aws_access_key_id" + secret_key = "aws_secret_access_key" + mock_session.return_value = Session(access_key, secret_key) + mock_session.return_value.access_key = access_key + mock_session.return_value.secret_key = secret_key + mock_session.return_value.token = None + + mock_connection.return_value = mock.MagicMock( + schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={} + ) + mock_hook.return_value = Connection() + mock_rs.execute_statement.return_value = {"Id": "STATEMENT_ID"} + mock_rs.describe_statement.return_value = {"Status": "FINISHED"} + + mock_rs_region.return_value = "region" + + schema = "schema" + table = "table" + s3_bucket = "bucket" + s3_key = "key" + copy_options = "" + + # RS Data API params + database = "database" + cluster_identifier = "cluster" + db_user = "db_user" + secret_arn = "secret_arn" + statement_name = "statement_name" + + op_rs_data = S3ToRedshiftOperator( + schema=schema, + table=table, + s3_bucket=s3_bucket, + s3_key=s3_key, + copy_options=copy_options, + redshift_conn_id="redshift_conn_id", + aws_conn_id="aws_conn_id", + task_id="task_id", + dag=None, + redshift_data_api_kwargs=dict( + database=database, + cluster_identifier=cluster_identifier, + db_user=db_user, + secret_arn=secret_arn, + statement_name=statement_name, + ), + ) + op_rs_data.execute(None) + rs_data_lineage = op_rs_data.get_openlineage_facets_on_complete(None) + + op_rs_sql = S3ToRedshiftOperator( + schema=schema, + table=table, + s3_bucket=s3_bucket, + s3_key=s3_key, + copy_options=copy_options, + redshift_conn_id="redshift_conn_id", + aws_conn_id="aws_conn_id", + task_id="task_id", + dag=None, + ) + op_rs_sql.execute(None) + rs_sql_lineage = op_rs_sql.get_openlineage_facets_on_complete(None) + + assert rs_sql_lineage.inputs == rs_data_lineage.inputs + assert len(rs_sql_lineage) == 1 + assert len(rs_data_lineage) == 1 + assert rs_sql_lineage.outputs[0].facets["schema"] == rs_data_lineage.outputs[0].facets["schema"] + assert ( + rs_sql_lineage.outputs[0].facets["columnLineage"] + == rs_data_lineage.outputs[0].facets["columnLineage"] + ) + assert rs_sql_lineage.outputs[0].name == rs_data_lineage.outputs[0].name + assert rs_sql_lineage.outputs[0].namespace == rs_data_lineage.outputs[0].namespace From 25e409153534f76eca8bd8f153d29f3844e4b060 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Artur=20Skar=C5=BCy=C5=84ski?= <33717106+Artuz37@users.noreply.github.com> Date: Tue, 13 Aug 2024 16:44:55 +0200 Subject: [PATCH 3/6] Update airflow/providers/amazon/aws/transfers/s3_to_redshift.py Co-authored-by: Vincent <97131062+vincbeck@users.noreply.github.com> --- airflow/providers/amazon/aws/transfers/s3_to_redshift.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py index 6a22cd8b2fb8..acae755e6d4d 100644 --- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py +++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py @@ -221,8 +221,8 @@ def get_openlineage_facets_on_complete(self, task_instance): if isinstance(redshift_hook, RedshiftDataHook): database = self.redshift_data_api_kwargs.get("database") identifier = self.redshift_data_api_kwargs.get( - "cluster_identifier" - ) or self.redshift_data_api_kwargs.get("workgroup_name") + "cluster_identifier", self.redshift_data_api_kwargs.get("workgroup_name") + ) port = self.redshift_data_api_kwargs.get("port", "5439") authority = 
f"{identifier}.{redshift_hook.region_name}:{port}" else: From 99c0f16cefb9646a32210aa5e67f3bf5f7f88cba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Artur=20Skar=C5=BCy=C5=84ski?= Date: Wed, 14 Aug 2024 15:01:39 +0200 Subject: [PATCH 4/6] rename open_lineage.py to openlineage.py --- airflow/providers/amazon/aws/transfers/s3_to_redshift.py | 2 +- .../amazon/aws/utils/{open_lineage.py => openlineage.py} | 0 tests/providers/amazon/aws/transfers/test_s3_to_redshift.py | 4 ++-- 3 files changed, 3 insertions(+), 3 deletions(-) rename airflow/providers/amazon/aws/utils/{open_lineage.py => openlineage.py} (100%) diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py index acae755e6d4d..e73ca5882316 100644 --- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py +++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py @@ -204,7 +204,7 @@ def get_openlineage_facets_on_complete(self, task_instance): """Implement on_complete as we will query destination table.""" from pathlib import Path - from airflow.providers.amazon.aws.utils.open_lineage import ( + from airflow.providers.amazon.aws.utils.openlineage import ( get_facets_from_redshift_table, get_identity_column_lineage_facet, ) diff --git a/airflow/providers/amazon/aws/utils/open_lineage.py b/airflow/providers/amazon/aws/utils/openlineage.py similarity index 100% rename from airflow/providers/amazon/aws/utils/open_lineage.py rename to airflow/providers/amazon/aws/utils/openlineage.py diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py index dacfdd4ac01c..f554ce869997 100644 --- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py +++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py @@ -795,8 +795,8 @@ def test_get_openlineage_facets_on_complete_data_and_sql_hooks_aligned( rs_sql_lineage = op_rs_sql.get_openlineage_facets_on_complete(None) assert rs_sql_lineage.inputs == rs_data_lineage.inputs - assert len(rs_sql_lineage) == 1 - assert len(rs_data_lineage) == 1 + assert len(rs_sql_lineage.outputs) == 1 + assert len(rs_data_lineage.outputs) == 1 assert rs_sql_lineage.outputs[0].facets["schema"] == rs_data_lineage.outputs[0].facets["schema"] assert ( rs_sql_lineage.outputs[0].facets["columnLineage"] From 3bebefdca99b47d8f8c22807588e2cd4d3e7d504 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Artur=20Skar=C5=BCy=C5=84ski?= Date: Wed, 14 Aug 2024 15:14:22 +0200 Subject: [PATCH 5/6] remove isinstance occurences for hook type checking in s3ToRedshiftOperator --- .../amazon/aws/transfers/s3_to_redshift.py | 49 ++++++++++--------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py index e73ca5882316..653885b54111 100644 --- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py +++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py @@ -121,6 +121,10 @@ def __init__( if arg in self.redshift_data_api_kwargs: raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs") + @property + def use_redshift_data(self): + return bool(self.redshift_data_api_kwargs) + def _build_copy_query( self, copy_destination: str, credentials_block: str, region_info: str, copy_options: str ) -> str: @@ -134,17 +138,15 @@ def _build_copy_query( {copy_options}; """ - def _create_hook(self) -> RedshiftDataHook | RedshiftSQLHook: - """If 
redshift_data_api_kwargs are provided, create RedshiftDataHook. RedshiftSQLHook otherwise.""" - if self.redshift_data_api_kwargs: - return RedshiftDataHook(aws_conn_id=self.redshift_conn_id) - return RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id) - def execute(self, context: Context) -> None: if self.method not in AVAILABLE_METHODS: raise AirflowException(f"Method not found! Available methods: {AVAILABLE_METHODS}") - redshift_hook = self._create_hook() + if self.use_redshift_data: + redshift_data_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id) + else: + redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id) + conn = S3Hook.get_connection(conn_id=self.aws_conn_id) if self.aws_conn_id else None region_info = "" if conn and conn.extra_dejson.get("region", False): @@ -169,12 +171,12 @@ def execute(self, context: Context) -> None: if self.method == "REPLACE": sql = ["BEGIN;", f"DELETE FROM {destination};", copy_statement, "COMMIT"] elif self.method == "UPSERT": - if isinstance(redshift_hook, RedshiftDataHook): - keys = self.upsert_keys or redshift_hook.get_table_primary_key( + if self.use_redshift_data: + keys = self.upsert_keys or redshift_data_hook.get_table_primary_key( table=self.table, schema=self.schema, **self.redshift_data_api_kwargs ) else: - keys = self.upsert_keys or redshift_hook.get_table_primary_key(self.table, self.schema) + keys = self.upsert_keys or redshift_sql_hook.get_table_primary_key(self.table, self.schema) if not keys: raise AirflowException( f"No primary key on {self.schema}.{self.table}. Please provide keys on 'upsert_keys'" @@ -194,10 +196,10 @@ def execute(self, context: Context) -> None: sql = copy_statement self.log.info("Executing COPY command...") - if isinstance(redshift_hook, RedshiftDataHook): - redshift_hook.execute_query(sql=sql, **self.redshift_data_api_kwargs) + if self.use_redshift_data: + redshift_data_hook.execute_query(sql=sql, **self.redshift_data_api_kwargs) else: - redshift_hook.run(sql, autocommit=self.autocommit) + redshift_sql_hook.run(sql, autocommit=self.autocommit) self.log.info("COPY command complete...") def get_openlineage_facets_on_complete(self, task_instance): @@ -217,21 +219,24 @@ def get_openlineage_facets_on_complete(self, task_instance): ) from airflow.providers.openlineage.extractors import OperatorLineage - redshift_hook = self._create_hook() - if isinstance(redshift_hook, RedshiftDataHook): + if self.use_redshift_data: + redshift_data_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id) database = self.redshift_data_api_kwargs.get("database") identifier = self.redshift_data_api_kwargs.get( "cluster_identifier", self.redshift_data_api_kwargs.get("workgroup_name") ) port = self.redshift_data_api_kwargs.get("port", "5439") - authority = f"{identifier}.{redshift_hook.region_name}:{port}" + authority = f"{identifier}.{redshift_data_hook.region_name}:{port}" + output_dataset_facets = get_facets_from_redshift_table( + redshift_data_hook, self.table, self.redshift_data_api_kwargs, self.schema + ) else: - database = redshift_hook.conn.schema - authority = redshift_hook.get_openlineage_database_info(redshift_hook.conn).authority - - output_dataset_facets = get_facets_from_redshift_table( - redshift_hook, self.table, self.redshift_data_api_kwargs, self.schema - ) + redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id) + database = redshift_sql_hook.conn.schema + authority = redshift_sql_hook.get_openlineage_database_info(redshift_sql_hook.conn).authority + 
output_dataset_facets = get_facets_from_redshift_table( + redshift_sql_hook, self.table, self.redshift_data_api_kwargs, self.schema + ) input_dataset_facets = {} if not self.column_list: From 21ecdf977a76b31c64f9327ecc55329c4b388345 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Artur=20Skar=C5=BCy=C5=84ski?= Date: Fri, 16 Aug 2024 13:39:58 +0200 Subject: [PATCH 6/6] add tests to aws/utils/openlineage.py --- .../amazon/aws/utils/test_openlineage.py | 168 ++++++++++++++++++ 1 file changed, 168 insertions(+) create mode 100644 tests/providers/amazon/aws/utils/test_openlineage.py diff --git a/tests/providers/amazon/aws/utils/test_openlineage.py b/tests/providers/amazon/aws/utils/test_openlineage.py new file mode 100644 index 000000000000..b3e820b58185 --- /dev/null +++ b/tests/providers/amazon/aws/utils/test_openlineage.py @@ -0,0 +1,168 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook +from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook +from airflow.providers.amazon.aws.utils.openlineage import ( + get_facets_from_redshift_table, + get_identity_column_lineage_facet, +) +from airflow.providers.common.compat.openlineage.facet import ( + ColumnLineageDatasetFacet, + Dataset, + Fields, + InputField, +) + + +@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.get_records") +def test_get_facets_from_redshift_table_sql_hook(mock_get_records): + mock_get_records.return_value = [ + ("column1", "varchar", "Column 1 description", "Table description"), + ("column2", "int", "Column 2 description", "Table description"), + ] + + mock_hook = RedshiftSQLHook() + + result = get_facets_from_redshift_table( + redshift_hook=mock_hook, table="my_table", redshift_data_api_kwargs={} + ) + + assert result["documentation"].description == "Table description" + assert len(result["schema"].fields) == 2 + assert result["schema"].fields[0].name == "column1" + assert result["schema"].fields[0].type == "varchar" + assert result["schema"].fields[0].description == "Column 1 description" + + +@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") +@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") +def test_get_facets_from_redshift_table_data_hook(mock_connection, mock_execute_query): + mock_execute_query.return_value = "statement_id" + mock_connection.get_statement_result.return_value = { + "Records": [ + [ + {"stringValue": "column1"}, + {"stringValue": "varchar"}, + {"stringValue": "Column 1 description"}, + {"stringValue": "Table description"}, + ], + [ + {"stringValue": "column2"}, + {"stringValue": "int"}, + 
{"stringValue": "Column 2 description"}, + {"stringValue": "Table description"}, + ], + ] + } + + mock_hook = RedshiftDataHook() + + result = get_facets_from_redshift_table( + redshift_hook=mock_hook, table="my_table", redshift_data_api_kwargs={} + ) + + assert result["documentation"].description == "Table description" + assert len(result["schema"].fields) == 2 + assert result["schema"].fields[0].name == "column1" + assert result["schema"].fields[0].type == "varchar" + assert result["schema"].fields[0].description == "Column 1 description" + + +@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.get_records") +def test_get_facets_no_records_sql_hook(mock_get_records): + mock_get_records.return_value = [] + + mock_hook = RedshiftSQLHook() + + result = get_facets_from_redshift_table( + redshift_hook=mock_hook, table="my_table", redshift_data_api_kwargs={} + ) + + assert result["documentation"].description == "" + assert len(result["schema"].fields) == 0 + + +def test_get_identity_column_lineage_facet_multiple_input_datasets(): + field_names = ["field1", "field2"] + input_datasets = [ + Dataset(namespace="s3://first_bucket", name="dir1"), + Dataset(namespace="s3://second_bucket", name="dir2"), + ] + expected_facet = ColumnLineageDatasetFacet( + fields={ + "field1": Fields( + inputFields=[ + InputField( + namespace="s3://first_bucket", + name="dir1", + field="field1", + ), + InputField( + namespace="s3://second_bucket", + name="dir2", + field="field1", + ), + ], + transformationType="IDENTITY", + transformationDescription="identical", + ), + "field2": Fields( + inputFields=[ + InputField( + namespace="s3://first_bucket", + name="dir1", + field="field2", + ), + InputField( + namespace="s3://second_bucket", + name="dir2", + field="field2", + ), + ], + transformationType="IDENTITY", + transformationDescription="identical", + ), + } + ) + result = get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets) + assert result == expected_facet + + +def test_get_identity_column_lineage_facet_no_field_names(): + field_names = [] + input_datasets = [ + Dataset(namespace="s3://first_bucket", name="dir1"), + Dataset(namespace="s3://second_bucket", name="dir2"), + ] + expected_facet = ColumnLineageDatasetFacet(fields={}) + result = get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets) + assert result == expected_facet + + +def test_get_identity_column_lineage_facet_no_input_datasets(): + field_names = ["field1", "field2"] + input_datasets = [] + + with pytest.raises(ValueError): + get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets)