Update example_redshift and example_redshift_s3_transfers to use RedshiftDataHook instead of RedshiftSQLHook #40970

Merged · 3 commits · Jul 24, 2024
1 change: 1 addition & 0 deletions airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -84,6 +84,7 @@ class RedshiftToS3Operator(BaseOperator):
"unload_options",
"select_query",
"redshift_conn_id",
"redshift_data_api_kwargs",
)
template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {"select_query": "sql"}
1 change: 1 addition & 0 deletions airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -77,6 +77,7 @@ class S3ToRedshiftOperator(BaseOperator):
"copy_options",
"redshift_conn_id",
"method",
"redshift_data_api_kwargs",
"aws_conn_id",
)
template_ext: Sequence[str] = ()
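Note: adding redshift_data_api_kwargs to template_fields on RedshiftToS3Operator and S3ToRedshiftOperator means the Data API arguments are now rendered by Jinja. A minimal sketch of what that allows; the DAG id, bucket, table, and Variable names below are hypothetical and not part of this change:

from datetime import datetime

from airflow.models.dag import DAG
from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator

with DAG(dag_id="templated_data_api_kwargs_sketch", start_date=datetime(2024, 1, 1), schedule=None):
    S3ToRedshiftOperator(
        task_id="copy_from_s3",
        s3_bucket="my-bucket",  # hypothetical bucket
        s3_key="data/fruit.csv",  # hypothetical key
        schema="PUBLIC",
        table="fruit",
        copy_options=["csv"],
        redshift_data_api_kwargs={
            # Rendered at runtime now that the field is templated.
            "database": "{{ var.value.redshift_db }}",
            "cluster_identifier": "{{ var.value.redshift_cluster_id }}",
            "db_user": "awsuser",
            "wait_for_completion": True,
        },
    )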
11 changes: 0 additions & 11 deletions tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
@@ -364,17 +364,6 @@ def test_table_unloading_role_arn(
         assert extra["role_arn"] in unload_query
         assert_equal_ignore_multiple_spaces(mock_run.call_args.args[0], unload_query)
 
-    def test_template_fields_overrides(self):
-        assert RedshiftToS3Operator.template_fields == (
-            "s3_bucket",
-            "s3_key",
-            "schema",
-            "table",
-            "unload_options",
-            "select_query",
-            "redshift_conn_id",
-        )
-
     @pytest.mark.parametrize("param", ["sql", "parameters"])
     def test_invalid_param_in_redshift_data_api_kwargs(self, param):
         """
13 changes: 0 additions & 13 deletions tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
@@ -381,19 +381,6 @@ def test_different_region(self, mock_run, mock_session, mock_connection, mock_ho
         assert mock_run.call_count == 1
         assert_equal_ignore_multiple_spaces(actual_copy_query, expected_copy_query)
 
-    def test_template_fields_overrides(self):
-        assert S3ToRedshiftOperator.template_fields == (
-            "s3_bucket",
-            "s3_key",
-            "schema",
-            "table",
-            "column_list",
-            "copy_options",
-            "redshift_conn_id",
-            "method",
-            "aws_conn_id",
-        )
-
     def test_execute_unavailable_method(self):
         """
         Test execute unavailable method
35 changes: 0 additions & 35 deletions tests/system/providers/amazon/aws/example_redshift.py
@@ -20,12 +20,8 @@

 from datetime import datetime
 
-from airflow import settings
-from airflow.decorators import task
-from airflow.models import Connection
 from airflow.models.baseoperator import chain
 from airflow.models.dag import DAG
-from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
 from airflow.providers.amazon.aws.operators.redshift_cluster import (
     RedshiftCreateClusterOperator,
     RedshiftCreateClusterSnapshotOperator,
@@ -36,7 +32,6 @@
 )
 from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
 from airflow.providers.amazon.aws.sensors.redshift_cluster import RedshiftClusterSensor
-from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
 from airflow.utils.trigger_rule import TriggerRule
 from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
 
@@ -56,24 +51,6 @@
 POLL_INTERVAL = 10
 
 
-@task
-def create_connection(conn_id_name: str, cluster_id: str):
-    redshift_hook = RedshiftHook()
-    cluster_endpoint = redshift_hook.get_conn().describe_clusters(ClusterIdentifier=cluster_id)["Clusters"][0]
-    conn = Connection(
-        conn_id=conn_id_name,
-        conn_type="redshift",
-        host=cluster_endpoint["Endpoint"]["Address"],
-        login=DB_LOGIN,
-        password=DB_PASS,
-        port=cluster_endpoint["Endpoint"]["Port"],
-        schema=cluster_endpoint["DBName"],
-    )
-    session = settings.Session()
-    session.add(conn)
-    session.commit()
-
-
 with DAG(
     dag_id=DAG_ID,
     start_date=datetime(2021, 1, 1),
@@ -87,7 +64,6 @@ def create_connection(conn_id_name: str, cluster_id: str):
     cluster_subnet_group_name = test_context[CLUSTER_SUBNET_GROUP_KEY]
     redshift_cluster_identifier = f"{env_id}-redshift-cluster"
     redshift_cluster_snapshot_identifier = f"{env_id}-snapshot"
-    conn_id_name = f"{env_id}-conn-id"
     sg_name = f"{env_id}-sg"
 
     # [START howto_operator_redshift_cluster]
@@ -164,8 +140,6 @@ def create_connection(conn_id_name: str, cluster_id: str):
         timeout=60 * 30,
     )
 
-    set_up_connection = create_connection(conn_id_name, cluster_id=redshift_cluster_identifier)
-
     # [START howto_operator_redshift_data]
     create_table_redshift_data = RedshiftDataOperator(
         task_id="create_table_redshift_data",
@@ -201,13 +175,6 @@ def create_connection(conn_id_name: str, cluster_id: str):
         wait_for_completion=True,
     )
 
-    drop_table = SQLExecuteQueryOperator(
-        task_id="drop_table",
-        conn_id=conn_id_name,
-        sql="DROP TABLE IF EXISTS fruit",
-        trigger_rule=TriggerRule.ALL_DONE,
-    )
-
     # [START howto_operator_redshift_delete_cluster]
     delete_cluster = RedshiftDeleteClusterOperator(
         task_id="delete_cluster",
@@ -236,10 +203,8 @@ def create_connection(conn_id_name: str, cluster_id: str):
         wait_cluster_paused,
         resume_cluster,
         wait_cluster_available_after_resume,
-        set_up_connection,
         create_table_redshift_data,
         insert_data,
-        drop_table,
         delete_cluster_snapshot,
         delete_cluster,
     )
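The drop_table task removed above was the last user of the connection that create_connection used to register. Below is a hedged sketch of how the same cleanup could be expressed with RedshiftDataOperator instead of SQLExecuteQueryOperator; the constants are placeholders standing in for the env-specific values defined in the example DAG, and this exact task is not part of the diff shown here:

from datetime import datetime

from airflow.models.dag import DAG
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
from airflow.utils.trigger_rule import TriggerRule

# Placeholders for the example's env-specific values.
DB_NAME = "dev"
DB_LOGIN = "adminuser"
redshift_cluster_identifier = "example-redshift-cluster"

with DAG(dag_id="drop_table_sketch", start_date=datetime(2024, 1, 1), schedule=None):
    drop_table = RedshiftDataOperator(
        task_id="drop_table",
        cluster_identifier=redshift_cluster_identifier,
        database=DB_NAME,
        db_user=DB_LOGIN,
        sql="DROP TABLE IF EXISTS fruit",
        wait_for_completion=True,
        trigger_rule=TriggerRule.ALL_DONE,  # run cleanup even if upstream tasks failed
    )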
tests/system/providers/amazon/aws/example_redshift_s3_transfers.py
@@ -18,12 +18,8 @@

 from datetime import datetime
 
-from airflow import settings
-from airflow.decorators import task
-from airflow.models import Connection
 from airflow.models.baseoperator import chain
 from airflow.models.dag import DAG
-from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
 from airflow.providers.amazon.aws.operators.redshift_cluster import (
     RedshiftCreateClusterOperator,
     RedshiftDeleteClusterOperator,
@@ -75,24 +71,6 @@
 DATA = "0, 'Airflow', 'testing'"
 
 
-@task
-def create_connection(conn_id_name: str, cluster_id: str):
-    redshift_hook = RedshiftHook()
-    cluster_endpoint = redshift_hook.get_conn().describe_clusters(ClusterIdentifier=cluster_id)["Clusters"][0]
-    conn = Connection(
-        conn_id=conn_id_name,
-        conn_type="redshift",
-        host=cluster_endpoint["Endpoint"]["Address"],
-        login=DB_LOGIN,
-        password=DB_PASS,
-        port=cluster_endpoint["Endpoint"]["Port"],
-        schema=cluster_endpoint["DBName"],
-    )
-    session = settings.Session()
-    session.add(conn)
-    session.commit()
-
-
 with DAG(
     dag_id=DAG_ID,
     start_date=datetime(2021, 1, 1),
@@ -105,7 +83,6 @@ def create_connection(conn_id_name: str, cluster_id: str):
     security_group_id = test_context[SECURITY_GROUP_KEY]
     cluster_subnet_group_name = test_context[CLUSTER_SUBNET_GROUP_KEY]
     redshift_cluster_identifier = f"{env_id}-redshift-cluster"
-    conn_id_name = f"{env_id}-conn-id"
     sg_name = f"{env_id}-sg"
     bucket_name = f"{env_id}-bucket"
 
@@ -134,8 +111,6 @@ def create_connection(conn_id_name: str, cluster_id: str):
         timeout=60 * 30,
     )
 
-    set_up_connection = create_connection(conn_id_name, cluster_id=redshift_cluster_identifier)
-
     create_object = S3CreateObjectOperator(
         task_id="create_object",
         s3_bucket=bucket_name,
@@ -165,7 +140,12 @@ def create_connection(conn_id_name: str, cluster_id: str):
     # [START howto_transfer_redshift_to_s3]
     transfer_redshift_to_s3 = RedshiftToS3Operator(
         task_id="transfer_redshift_to_s3",
-        redshift_conn_id=conn_id_name,
+        redshift_data_api_kwargs={
+            "database": DB_NAME,
+            "cluster_identifier": redshift_cluster_identifier,
+            "db_user": DB_LOGIN,
+            "wait_for_completion": True,
+        },
         s3_bucket=bucket_name,
         s3_key=S3_KEY,
         schema="PUBLIC",
@@ -182,7 +162,12 @@ def create_connection(conn_id_name: str, cluster_id: str):
     # [START howto_transfer_s3_to_redshift]
     transfer_s3_to_redshift = S3ToRedshiftOperator(
         task_id="transfer_s3_to_redshift",
-        redshift_conn_id=conn_id_name,
+        redshift_data_api_kwargs={
+            "database": DB_NAME,
+            "cluster_identifier": redshift_cluster_identifier,
+            "db_user": DB_LOGIN,
+            "wait_for_completion": True,
+        },
         s3_bucket=bucket_name,
         s3_key=S3_KEY_2,
         schema="PUBLIC",
@@ -194,7 +179,12 @@ def create_connection(conn_id_name: str, cluster_id: str):
     # [START howto_transfer_s3_to_redshift_multiple_keys]
     transfer_s3_to_redshift_multiple = S3ToRedshiftOperator(
         task_id="transfer_s3_to_redshift_multiple",
-        redshift_conn_id=conn_id_name,
+        redshift_data_api_kwargs={
+            "database": DB_NAME,
+            "cluster_identifier": redshift_cluster_identifier,
+            "db_user": DB_LOGIN,
+            "wait_for_completion": True,
+        },
         s3_bucket=bucket_name,
         s3_key=S3_KEY_PREFIX,
         schema="PUBLIC",
@@ -231,7 +221,6 @@ def create_connection(conn_id_name: str, cluster_id: str):
         create_bucket,
         create_cluster,
         wait_cluster_available,
-        set_up_connection,
         create_object,
         create_table_redshift_data,
         insert_data,
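For readers less familiar with the Redshift Data API path these examples now take, the sketch below shows roughly what it amounts to at the boto3 level: statements are submitted through the redshift-data client rather than over a SQL connection, which is why no Airflow connection has to be created. The region, cluster identifier, database, and user are placeholders, and the provider's hook and operators wrap these calls for you.

import time

import boto3

client = boto3.client("redshift-data", region_name="us-east-1")  # placeholder region

response = client.execute_statement(
    ClusterIdentifier="example-redshift-cluster",  # placeholder cluster
    Database="dev",
    DbUser="adminuser",
    Sql="SELECT count(*) FROM fruit;",
)

# Poll until the statement finishes, mirroring wait_for_completion=True.
while True:
    status = client.describe_statement(Id=response["Id"])["Status"]
    if status in ("FINISHED", "FAILED", "ABORTED"):
        break
    time.sleep(5)
print(status)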