From 010068bdd71fb550fe9287c852ce63a42089033e Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Tue, 23 Jul 2024 16:22:05 -0400
Subject: [PATCH 1/3] Update `example_redshift` and `example_redshift_s3_transfers` to use `RedshiftDataHook` instead of `RedshiftSQLHook`

---
 .../amazon/aws/transfers/redshift_to_s3.py    |  1 +
 .../amazon/aws/transfers/s3_to_redshift.py    |  1 +
 .../providers/amazon/aws/example_redshift.py  | 35 --------------
 .../aws/example_redshift_s3_transfers.py      | 47 +++++++------------
 4 files changed, 20 insertions(+), 64 deletions(-)

diff --git a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
index f6aafeba5927..73578ea539b7 100644
--- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -84,6 +84,7 @@ class RedshiftToS3Operator(BaseOperator):
         "unload_options",
         "select_query",
         "redshift_conn_id",
+        "redshift_data_api_kwargs",
     )
     template_ext: Sequence[str] = (".sql",)
     template_fields_renderers = {"select_query": "sql"}
diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
index 6418c111e249..161276b33cb0 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -77,6 +77,7 @@ class S3ToRedshiftOperator(BaseOperator):
         "copy_options",
         "redshift_conn_id",
         "method",
+        "redshift_data_api_kwargs",
         "aws_conn_id",
     )
     template_ext: Sequence[str] = ()
diff --git a/tests/system/providers/amazon/aws/example_redshift.py b/tests/system/providers/amazon/aws/example_redshift.py
index 84be4c702cd0..cc88811bef28 100644
--- a/tests/system/providers/amazon/aws/example_redshift.py
+++ b/tests/system/providers/amazon/aws/example_redshift.py
@@ -20,12 +20,8 @@
 
 from datetime import datetime
 
-from airflow import settings
-from airflow.decorators import task
-from airflow.models import Connection
 from airflow.models.baseoperator import chain
 from airflow.models.dag import DAG
-from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
 from airflow.providers.amazon.aws.operators.redshift_cluster import (
     RedshiftCreateClusterOperator,
     RedshiftCreateClusterSnapshotOperator,
@@ -36,7 +32,6 @@
 )
 from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
 from airflow.providers.amazon.aws.sensors.redshift_cluster import RedshiftClusterSensor
-from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
 from airflow.utils.trigger_rule import TriggerRule
 from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
 
@@ -56,24 +51,6 @@
 POLL_INTERVAL = 10
 
 
-@task
-def create_connection(conn_id_name: str, cluster_id: str):
-    redshift_hook = RedshiftHook()
-    cluster_endpoint = redshift_hook.get_conn().describe_clusters(ClusterIdentifier=cluster_id)["Clusters"][0]
-    conn = Connection(
-        conn_id=conn_id_name,
-        conn_type="redshift",
-        host=cluster_endpoint["Endpoint"]["Address"],
-        login=DB_LOGIN,
-        password=DB_PASS,
-        port=cluster_endpoint["Endpoint"]["Port"],
-        schema=cluster_endpoint["DBName"],
-    )
-    session = settings.Session()
-    session.add(conn)
-    session.commit()
-
-
 with DAG(
     dag_id=DAG_ID,
     start_date=datetime(2021, 1, 1),
@@ -87,7 +64,6 @@ def create_connection(conn_id_name: str, cluster_id: str):
     cluster_subnet_group_name = test_context[CLUSTER_SUBNET_GROUP_KEY]
     redshift_cluster_identifier = f"{env_id}-redshift-cluster"
     redshift_cluster_snapshot_identifier = f"{env_id}-snapshot"
-    conn_id_name = f"{env_id}-conn-id"
     sg_name = f"{env_id}-sg"
 
     # [START howto_operator_redshift_cluster]
@@ -164,8 +140,6 @@ def create_connection(conn_id_name: str, cluster_id: str):
         timeout=60 * 30,
     )
 
-    set_up_connection = create_connection(conn_id_name, cluster_id=redshift_cluster_identifier)
-
     # [START howto_operator_redshift_data]
     create_table_redshift_data = RedshiftDataOperator(
         task_id="create_table_redshift_data",
@@ -201,13 +175,6 @@ def create_connection(conn_id_name: str, cluster_id: str):
         wait_for_completion=True,
     )
 
-    drop_table = SQLExecuteQueryOperator(
-        task_id="drop_table",
-        conn_id=conn_id_name,
-        sql="DROP TABLE IF EXISTS fruit",
-        trigger_rule=TriggerRule.ALL_DONE,
-    )
-
     # [START howto_operator_redshift_delete_cluster]
     delete_cluster = RedshiftDeleteClusterOperator(
         task_id="delete_cluster",
@@ -236,10 +203,8 @@ def create_connection(conn_id_name: str, cluster_id: str):
         wait_cluster_paused,
         resume_cluster,
         wait_cluster_available_after_resume,
-        set_up_connection,
         create_table_redshift_data,
         insert_data,
-        drop_table,
         delete_cluster_snapshot,
         delete_cluster,
     )
diff --git a/tests/system/providers/amazon/aws/example_redshift_s3_transfers.py b/tests/system/providers/amazon/aws/example_redshift_s3_transfers.py
index 4fbf728fa834..069104619050 100644
--- a/tests/system/providers/amazon/aws/example_redshift_s3_transfers.py
+++ b/tests/system/providers/amazon/aws/example_redshift_s3_transfers.py
@@ -18,12 +18,8 @@
 
 from datetime import datetime
 
-from airflow import settings
-from airflow.decorators import task
-from airflow.models import Connection
 from airflow.models.baseoperator import chain
 from airflow.models.dag import DAG
-from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
 from airflow.providers.amazon.aws.operators.redshift_cluster import (
     RedshiftCreateClusterOperator,
     RedshiftDeleteClusterOperator,
@@ -75,24 +71,6 @@
 DATA = "0, 'Airflow', 'testing'"
 
 
-@task
-def create_connection(conn_id_name: str, cluster_id: str):
-    redshift_hook = RedshiftHook()
-    cluster_endpoint = redshift_hook.get_conn().describe_clusters(ClusterIdentifier=cluster_id)["Clusters"][0]
-    conn = Connection(
-        conn_id=conn_id_name,
-        conn_type="redshift",
-        host=cluster_endpoint["Endpoint"]["Address"],
-        login=DB_LOGIN,
-        password=DB_PASS,
-        port=cluster_endpoint["Endpoint"]["Port"],
-        schema=cluster_endpoint["DBName"],
-    )
-    session = settings.Session()
-    session.add(conn)
-    session.commit()
-
-
 with DAG(
     dag_id=DAG_ID,
     start_date=datetime(2021, 1, 1),
@@ -105,7 +83,6 @@ def create_connection(conn_id_name: str, cluster_id: str):
     security_group_id = test_context[SECURITY_GROUP_KEY]
     cluster_subnet_group_name = test_context[CLUSTER_SUBNET_GROUP_KEY]
     redshift_cluster_identifier = f"{env_id}-redshift-cluster"
-    conn_id_name = f"{env_id}-conn-id"
     sg_name = f"{env_id}-sg"
     bucket_name = f"{env_id}-bucket"
 
@@ -134,8 +111,6 @@ def create_connection(conn_id_name: str, cluster_id: str):
         timeout=60 * 30,
     )
 
-    set_up_connection = create_connection(conn_id_name, cluster_id=redshift_cluster_identifier)
-
     create_object = S3CreateObjectOperator(
         task_id="create_object",
         s3_bucket=bucket_name,
@@ -165,7 +140,12 @@ def create_connection(conn_id_name: str, cluster_id: str):
     # [START howto_transfer_redshift_to_s3]
     transfer_redshift_to_s3 = RedshiftToS3Operator(
         task_id="transfer_redshift_to_s3",
-        redshift_conn_id=conn_id_name,
+        redshift_data_api_kwargs={
+            "database": DB_NAME,
+            "cluster_identifier": redshift_cluster_identifier,
"db_user": DB_LOGIN, + "wait_for_completion": True, + }, s3_bucket=bucket_name, s3_key=S3_KEY, schema="PUBLIC", @@ -182,7 +162,12 @@ def create_connection(conn_id_name: str, cluster_id: str): # [START howto_transfer_s3_to_redshift] transfer_s3_to_redshift = S3ToRedshiftOperator( task_id="transfer_s3_to_redshift", - redshift_conn_id=conn_id_name, + redshift_data_api_kwargs={ + "database": DB_NAME, + "cluster_identifier": redshift_cluster_identifier, + "db_user": DB_LOGIN, + "wait_for_completion": True, + }, s3_bucket=bucket_name, s3_key=S3_KEY_2, schema="PUBLIC", @@ -194,7 +179,12 @@ def create_connection(conn_id_name: str, cluster_id: str): # [START howto_transfer_s3_to_redshift_multiple_keys] transfer_s3_to_redshift_multiple = S3ToRedshiftOperator( task_id="transfer_s3_to_redshift_multiple", - redshift_conn_id=conn_id_name, + redshift_data_api_kwargs={ + "database": DB_NAME, + "cluster_identifier": redshift_cluster_identifier, + "db_user": DB_LOGIN, + "wait_for_completion": True, + }, s3_bucket=bucket_name, s3_key=S3_KEY_PREFIX, schema="PUBLIC", @@ -231,7 +221,6 @@ def create_connection(conn_id_name: str, cluster_id: str): create_bucket, create_cluster, wait_cluster_available, - set_up_connection, create_object, create_table_redshift_data, insert_data, From a4ec710a4909d8e4ed0e7504eb79d6b183b4aa61 Mon Sep 17 00:00:00 2001 From: vincbeck Date: Tue, 23 Jul 2024 16:44:47 -0400 Subject: [PATCH 2/3] Remove test --- .../amazon/aws/transfers/test_redshift_to_s3.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py b/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py index d025b4836f8a..d2af90a445e2 100644 --- a/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py @@ -364,17 +364,6 @@ def test_table_unloading_role_arn( assert extra["role_arn"] in unload_query assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0], unload_query) - def test_template_fields_overrides(self): - assert RedshiftToS3Operator.template_fields == ( - "s3_bucket", - "s3_key", - "schema", - "table", - "unload_options", - "select_query", - "redshift_conn_id", - ) - @pytest.mark.parametrize("param", ["sql", "parameters"]) def test_invalid_param_in_redshift_data_api_kwargs(self, param): """ From 95273424bf5e3829d18aa36235efcac2f49c3251 Mon Sep 17 00:00:00 2001 From: vincbeck Date: Tue, 23 Jul 2024 17:07:59 -0400 Subject: [PATCH 3/3] Remove test --- .../amazon/aws/transfers/test_s3_to_redshift.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py index 6e3cbb2a1ca5..cb5ef7fdb75b 100644 --- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py +++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py @@ -381,19 +381,6 @@ def test_different_region(self, mock_run, mock_session, mock_connection, mock_ho assert mock_run.call_count == 1 assert_equal_ignore_multiple_spaces(actual_copy_query, expected_copy_query) - def test_template_fields_overrides(self): - assert S3ToRedshiftOperator.template_fields == ( - "s3_bucket", - "s3_key", - "schema", - "table", - "column_list", - "copy_options", - "redshift_conn_id", - "method", - "aws_conn_id", - ) - def test_execute_unavailable_method(self): """ Test execute unavailable method