Add RedshiftDataHook (#19137)
Use the AWS `redshift-data` API to interact with AWS Redshift clusters

Co-authored-by: john-jac <[email protected]>
john-jac authored Mar 2, 2022
1 parent ba79adb commit af79df6
Showing 9 changed files with 501 additions and 12 deletions.
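The `redshift-data` API is asynchronous: a statement is submitted with `ExecuteStatement`, its progress is polled with `DescribeStatement`, and rows are fetched with `GetStatementResult`; this is the flow that the hook and operator below wrap. A minimal boto3 sketch of that flow (the cluster, database, and user names are placeholders):

import time

import boto3

# Placeholder identifiers; substitute your own cluster, database, and user.
client = boto3.client("redshift-data")
stmt = client.execute_statement(
    ClusterIdentifier="test-cluster",
    Database="test-database",
    DbUser="awsuser",
    Sql="SELECT 1;",
)

# Poll until the statement reaches a terminal state, then fetch the result set.
status = client.describe_statement(Id=stmt["Id"])["Status"]
while status not in ("FINISHED", "FAILED", "ABORTED"):
    time.sleep(10)
    status = client.describe_statement(Id=stmt["Id"])["Status"]
if status == "FINISHED":
    print(client.get_statement_result(Id=stmt["Id"])["Records"])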
@@ -0,0 +1,80 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from datetime import datetime, timedelta
from os import getenv

from airflow.decorators import dag, task
from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

# [START howto_operator_redshift_data_env_variables]
REDSHIFT_CLUSTER_IDENTIFIER = getenv("REDSHIFT_CLUSTER_IDENTIFIER", "test-cluster")
REDSHIFT_DATABASE = getenv("REDSHIFT_DATABASE", "test-database")
REDSHIFT_DATABASE_USER = getenv("REDSHIFT_DATABASE_USER", "awsuser")
# [END howto_operator_redshift_data_env_variables]

REDSHIFT_QUERY = """
SELECT table_schema,
table_name
FROM information_schema.tables
WHERE table_schema NOT IN ('information_schema', 'pg_catalog')
AND table_type = 'BASE TABLE'
ORDER BY table_schema,
table_name;
"""
POLL_INTERVAL = 10


# [START howto_redshift_data]
@dag(
    dag_id='example_redshift_data',
    schedule_interval=None,
    start_date=datetime(2021, 1, 1),
    dagrun_timeout=timedelta(minutes=60),
    tags=['example'],
    catchup=False,
)
def example_redshift_data():
    @task(task_id="output_results")
    def output_results_fn(id):
        """Fetch and print the results of the Redshift Data statement with the given id"""
        hook = RedshiftDataHook()

        # The hook's `conn` property exposes the boto3 `redshift-data` client
        resp = hook.conn.get_statement_result(
            Id=id,
        )
        print(resp)
        return resp

    # Run a SQL statement and wait for completion
    redshift_query = RedshiftDataOperator(
        task_id='redshift_query',
        cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER,
        database=REDSHIFT_DATABASE,
        db_user=REDSHIFT_DATABASE_USER,
        sql=REDSHIFT_QUERY,
        poll_interval=POLL_INTERVAL,
        await_result=True,
    )

    # Use a task-decorated function to output the list of tables in the Redshift cluster
    output_results_fn(redshift_query.output)


example_redshift_data_dag = example_redshift_data()
# [END howto_redshift_data]
52 changes: 52 additions & 0 deletions airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -0,0 +1,52 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import TYPE_CHECKING

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

if TYPE_CHECKING:
    from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient


class RedshiftDataHook(AwsBaseHook):
    """
    Interact with AWS Redshift Data, using the boto3 library.

    The hook attribute ``conn`` exposes all of the methods listed in the documentation:

    .. seealso::
        - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift-data.html
        - https://docs.aws.amazon.com/redshift-data/latest/APIReference/Welcome.html

    Additional arguments (such as ``aws_conn_id`` or ``region_name``) may be specified and
    are passed down to the underlying AwsBaseHook.

    .. seealso::
        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`

    :param aws_conn_id: The Airflow connection used for AWS credentials.
    """

    def __init__(self, *args, **kwargs) -> None:
        kwargs["client_type"] = "redshift-data"
        super().__init__(*args, **kwargs)

    @property
    def conn(self) -> 'RedshiftDataAPIServiceClient':
        """Get the underlying boto3 RedshiftDataAPIService client (cached)"""
        return super().conn
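Because the hook only sets `client_type` and exposes the boto3 client through `conn`, any Redshift Data API call can be made directly on it. A minimal sketch, assuming an `aws_default` connection and the placeholder cluster from the example DAG above:

from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

hook = RedshiftDataHook(aws_conn_id="aws_default")
# `conn` is the cached boto3 "redshift-data" client created by AwsBaseHook
stmt = hook.conn.execute_statement(
    ClusterIdentifier="test-cluster",  # placeholder cluster name
    Database="test-database",
    DbUser="awsuser",
    Sql="SELECT current_user;",
)
print(hook.conn.describe_statement(Id=stmt["Id"])["Status"])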
159 changes: 159 additions & 0 deletions airflow/providers/amazon/aws/operators/redshift_data.py
@@ -0,0 +1,159 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import sys
from time import sleep
from typing import TYPE_CHECKING, Optional

if sys.version_info >= (3, 8):
    from functools import cached_property
else:
    from cached_property import cached_property

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

if TYPE_CHECKING:
    from airflow.utils.context import Context


class RedshiftDataOperator(BaseOperator):
    """
    Executes SQL Statements against an Amazon Redshift cluster using Redshift Data.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:RedshiftDataOperator`

    :param database: the name of the database
    :param sql: the SQL statement text to run
    :param cluster_identifier: unique identifier of a cluster
    :param db_user: the database username
    :param parameters: the parameters for the SQL statement
    :param secret_arn: the name or ARN of the secret that enables db access
    :param statement_name: the name of the SQL statement
    :param with_event: indicates whether to send an event to EventBridge
    :param await_result: indicates whether to wait for the statement to finish before completing the task
    :param poll_interval: how often in seconds to check the query status
    :param aws_conn_id: aws connection to use
    :param region: aws region to use
    """

    template_fields = (
        'cluster_identifier',
        'database',
        'sql',
        'db_user',
        'parameters',
        'statement_name',
        'aws_conn_id',
        'region',
    )
    template_ext = ('.sql',)
    template_fields_renderers = {'sql': 'sql'}

    def __init__(
        self,
        database: str,
        sql: str,
        cluster_identifier: Optional[str] = None,
        db_user: Optional[str] = None,
        parameters: Optional[list] = None,
        secret_arn: Optional[str] = None,
        statement_name: Optional[str] = None,
        with_event: bool = False,
        await_result: bool = True,
        poll_interval: int = 10,
        aws_conn_id: str = 'aws_default',
        region: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.database = database
        self.sql = sql
        self.cluster_identifier = cluster_identifier
        self.db_user = db_user
        self.parameters = parameters
        self.secret_arn = secret_arn
        self.statement_name = statement_name
        self.with_event = with_event
        self.await_result = await_result
        if poll_interval > 0:
            self.poll_interval = poll_interval
        else:
            # Fall back to the default so that wait_for_results() always has a usable interval
            self.poll_interval = 10
            self.log.warning("Invalid poll_interval: %s; using %s seconds instead", poll_interval, self.poll_interval)
        self.aws_conn_id = aws_conn_id
        self.region = region
        self.statement_id: Optional[str] = None

    @cached_property
    def hook(self) -> RedshiftDataHook:
        """Create and return a RedshiftDataHook."""
        return RedshiftDataHook(aws_conn_id=self.aws_conn_id, region_name=self.region)

    def execute_query(self) -> str:
        resp = self.hook.conn.execute_statement(
            ClusterIdentifier=self.cluster_identifier,
            Database=self.database,
            DbUser=self.db_user,
            Sql=self.sql,
            Parameters=self.parameters,
            SecretArn=self.secret_arn,
            StatementName=self.statement_name,
            WithEvent=self.with_event,
        )
        return resp['Id']

    def wait_for_results(self, statement_id):
        while True:
            self.log.info("Polling statement %s", statement_id)
            resp = self.hook.conn.describe_statement(
                Id=statement_id,
            )
            status = resp['Status']
            if status == 'FINISHED':
                return status
            elif status in ('FAILED', 'ABORTED'):
                raise ValueError(f"Statement {statement_id!r} terminated with status {status}.")
            else:
                self.log.info("Query status: %s", status)
            sleep(self.poll_interval)

    def execute(self, context: 'Context') -> str:
        """Execute a statement against Amazon Redshift and return the statement id"""
        self.log.info("Executing statement: %s", self.sql)

        self.statement_id = self.execute_query()

        if self.await_result:
            self.wait_for_results(self.statement_id)

        return self.statement_id

    def on_kill(self) -> None:
        """Cancel the submitted Redshift query"""
        if self.statement_id:
            self.log.info('Received a kill signal.')
            self.log.info('Stopping Query with statementId - %s', self.statement_id)

            try:
                self.hook.conn.cancel_statement(Id=self.statement_id)
            except Exception as ex:
                self.log.error('Unable to cancel query. Exiting. %s', ex)
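Because `execute()` returns the statement id (which Airflow pushes to XCom as the task's return value), setting `await_result=False` allows a submit-and-continue pattern where a downstream task polls or fetches the results itself. A minimal sketch of a classic, non-TaskFlow DAG using the operator; the DAG id and query are illustrative only:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

with DAG(
    dag_id="example_redshift_data_submit_only",  # illustrative DAG id
    start_date=datetime(2021, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    # Submit the statement and return immediately; downstream tasks can read
    # the statement id from `submit_query.output` and poll or fetch results.
    submit_query = RedshiftDataOperator(
        task_id="submit_query",
        cluster_identifier="test-cluster",
        database="test-database",
        db_user="awsuser",
        sql="SELECT count(*) FROM information_schema.tables;",
        await_result=False,
    )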
10 changes: 10 additions & 0 deletions airflow/providers/amazon/provider.yaml
@@ -114,6 +114,12 @@ integrations:
      - /docs/apache-airflow-providers-amazon/operators/redshift_sql.rst
      - /docs/apache-airflow-providers-amazon/operators/redshift_cluster.rst
    tags: [aws]
  - integration-name: Amazon Redshift Data
    external-doc-url: https://aws.amazon.com/redshift/
    logo: /integration-logos/aws/[email protected]
    how-to-guide:
      - /docs/apache-airflow-providers-amazon/operators/redshift_data.rst
    tags: [aws]
  - integration-name: Amazon SageMaker
    external-doc-url: https://aws.amazon.com/sagemaker/
    logo: /integration-logos/aws/[email protected]
@@ -271,6 +277,7 @@ operators:
      - airflow.providers.amazon.aws.operators.redshift
      - airflow.providers.amazon.aws.operators.redshift_sql
      - airflow.providers.amazon.aws.operators.redshift_cluster
      - airflow.providers.amazon.aws.operators.redshift_data

sensors:
  - integration-name: Amazon Athena
@@ -401,6 +408,7 @@ hooks:
      - airflow.providers.amazon.aws.hooks.redshift
      - airflow.providers.amazon.aws.hooks.redshift_sql
      - airflow.providers.amazon.aws.hooks.redshift_cluster
      - airflow.providers.amazon.aws.hooks.redshift_data
  - integration-name: Amazon Simple Storage Service (S3)
    python-modules:
      - airflow.providers.amazon.aws.hooks.s3
@@ -505,6 +513,8 @@ connection-types:
    connection-type: emr
  - hook-class-name: airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook
    connection-type: redshift
  - hook-class-name: airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook
    connection-type: aws

secrets-backends:
  - airflow.providers.amazon.aws.secrets.secrets_manager.SecretsManagerBackend
25 changes: 13 additions & 12 deletions docs/apache-airflow-providers-amazon/index.rst
@@ -77,18 +77,19 @@ You can install this package on top of an existing Airflow 2.1+ installation via
PIP requirements
----------------

======================= ===================
PIP package             Version required
======================= ===================
``apache-airflow``      ``>=2.1.0``
``boto3``               ``>=1.15.0,<2.0.0``
``jsonpath_ng``         ``>=1.5.3``
``pandas``              ``>=0.17.1, <1.4``
``redshift_connector``  ``~=2.0.888``
``sqlalchemy_redshift`` ``~=0.8.6``
``watchtower``          ``~=2.0.1``
``mypy-boto3-rds``      ``>=1.21.0``
======================= ===================
============================ ===================
PIP package                  Version required
============================ ===================
``apache-airflow``           ``>=2.1.0``
``boto3``                    ``>=1.15.0,<2.0.0``
``jsonpath_ng``              ``>=1.5.3``
``pandas``                   ``>=0.17.1, <1.4``
``redshift_connector``       ``~=2.0.888``
``sqlalchemy_redshift``      ``~=0.8.6``
``watchtower``               ``~=2.0.1``
``mypy-boto3-rds``           ``>=1.21.0``
``mypy-boto3-redshift-data`` ``>=1.21.0``
============================ ===================

Cross provider package dependencies
-----------------------------------