From af79df69b2fe4d71fa244087d5cfd031d809bc7e Mon Sep 17 00:00:00 2001
From: john-jac <75442233+john-jac@users.noreply.github.com>
Date: Wed, 2 Mar 2022 12:18:33 -0800
Subject: [PATCH] Add RedshiftDataHook (#19137)

Use the AWS `redshift-data` API to interact with Amazon Redshift clusters

Co-authored-by: john-jac
---
 .../example_redshift_data_execute_sql.py      |  80 +++++++++
 .../amazon/aws/hooks/redshift_data.py         |  52 ++++++
 .../amazon/aws/operators/redshift_data.py     | 159 ++++++++++++++++++
 airflow/providers/amazon/provider.yaml        |  10 ++
 .../apache-airflow-providers-amazon/index.rst |  25 +--
 .../operators/redshift_data.rst               |  52 ++++++
 setup.py                                      |   1 +
 .../amazon/aws/hooks/test_redshift_data.py    |  26 +++
 .../aws/operators/test_redshift_data.py       | 108 ++++++++++++
 9 files changed, 501 insertions(+), 12 deletions(-)
 create mode 100644 airflow/providers/amazon/aws/example_dags/example_redshift_data_execute_sql.py
 create mode 100644 airflow/providers/amazon/aws/hooks/redshift_data.py
 create mode 100644 airflow/providers/amazon/aws/operators/redshift_data.py
 create mode 100644 docs/apache-airflow-providers-amazon/operators/redshift_data.rst
 create mode 100644 tests/providers/amazon/aws/hooks/test_redshift_data.py
 create mode 100644 tests/providers/amazon/aws/operators/test_redshift_data.py

diff --git a/airflow/providers/amazon/aws/example_dags/example_redshift_data_execute_sql.py b/airflow/providers/amazon/aws/example_dags/example_redshift_data_execute_sql.py
new file mode 100644
index 0000000000000..4806d66f3f5aa
--- /dev/null
+++ b/airflow/providers/amazon/aws/example_dags/example_redshift_data_execute_sql.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime, timedelta
+from os import getenv
+
+from airflow.decorators import dag, task
+from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
+from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
+
+# [START howto_operator_redshift_data_env_variables]
+REDSHIFT_CLUSTER_IDENTIFIER = getenv("REDSHIFT_CLUSTER_IDENTIFIER", "test-cluster")
+REDSHIFT_DATABASE = getenv("REDSHIFT_DATABASE", "test-database")
+REDSHIFT_DATABASE_USER = getenv("REDSHIFT_DATABASE_USER", "awsuser")
+# [END howto_operator_redshift_data_env_variables]
+
+REDSHIFT_QUERY = """
+SELECT table_schema,
+       table_name
+FROM information_schema.tables
+WHERE table_schema NOT IN ('information_schema', 'pg_catalog')
+      AND table_type = 'BASE TABLE'
+ORDER BY table_schema,
+         table_name;
+    """
+POLL_INTERVAL = 10
+
+
+# [START howto_redshift_data]
+@dag(
+    dag_id='example_redshift_data',
+    schedule_interval=None,
+    start_date=datetime(2021, 1, 1),
+    dagrun_timeout=timedelta(minutes=60),
+    tags=['example'],
+    catchup=False,
+)
+def example_redshift_data():
+    @task(task_id="output_results")
+    def output_results_fn(statement_id):
+        """A task-decorated function that retrieves the results of a finished Redshift Data statement"""
+        hook = RedshiftDataHook()
+
+        resp = hook.conn.get_statement_result(
+            Id=statement_id,
+        )
+        print(resp)
+        return resp
+
+    # Run a SQL statement and wait for completion
+    redshift_query = RedshiftDataOperator(
+        task_id='redshift_query',
+        cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER,
+        database=REDSHIFT_DATABASE,
+        db_user=REDSHIFT_DATABASE_USER,
+        sql=REDSHIFT_QUERY,
+        poll_interval=POLL_INTERVAL,
+        await_result=True,
+    )
+
+    # Use a task-decorated function to output the list of tables in the Redshift cluster
+    output_results_fn(redshift_query.output)
+
+
+example_redshift_data_dag = example_redshift_data()
+# [END howto_redshift_data]
diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py
new file mode 100644
index 0000000000000..74459f58a5f8a
--- /dev/null
+++ b/airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -0,0 +1,52 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+if TYPE_CHECKING:
+    from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient
+
+
+class RedshiftDataHook(AwsBaseHook):
+    """
+    Interact with the AWS Redshift Data API, using the boto3 library.
+    Hook attribute `conn` has all the methods listed in the documentation
+
+    .. 
seealso:: + - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift-data.html + - https://docs.aws.amazon.com/redshift-data/latest/APIReference/Welcome.html + + Additional arguments (such as ``aws_conn_id`` or ``region_name``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + + :param aws_conn_id: The Airflow connection used for AWS credentials. + """ + + def __init__(self, *args, **kwargs) -> None: + kwargs["client_type"] = "redshift-data" + super().__init__(*args, **kwargs) + + @property + def conn(self) -> 'RedshiftDataAPIServiceClient': + """Get the underlying boto3 RedshiftDataAPIService client (cached)""" + return super().conn diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py new file mode 100644 index 0000000000000..977bc68675191 --- /dev/null +++ b/airflow/providers/amazon/aws/operators/redshift_data.py @@ -0,0 +1,159 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import sys +from time import sleep +from typing import TYPE_CHECKING, Optional + +if sys.version_info >= (3, 8): + from functools import cached_property +else: + from cached_property import cached_property + +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class RedshiftDataOperator(BaseOperator): + """ + Executes SQL Statements against an Amazon Redshift cluster using Redshift Data + + .. 
seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:RedshiftDataOperator`
+
+    :param database: the name of the database
+    :param sql: the SQL statement text to run
+    :param cluster_identifier: unique identifier of a cluster
+    :param db_user: the database username
+    :param parameters: the parameters for the SQL statement
+    :param secret_arn: the name or ARN of the secret that enables access to the database
+    :param statement_name: the name of the SQL statement
+    :param with_event: indicates whether to send an event to EventBridge
+    :param await_result: indicates whether to wait for the statement to finish; if False, return immediately after submitting it
+    :param poll_interval: how often in seconds to check the query status
+    :param aws_conn_id: AWS connection to use
+    :param region: AWS region to use
+    """
+
+    template_fields = (
+        'cluster_identifier',
+        'database',
+        'sql',
+        'db_user',
+        'parameters',
+        'statement_name',
+        'aws_conn_id',
+        'region',
+    )
+    template_ext = ('.sql',)
+    template_fields_renderers = {'sql': 'sql'}
+
+    def __init__(
+        self,
+        database: str,
+        sql: str,
+        cluster_identifier: Optional[str] = None,
+        db_user: Optional[str] = None,
+        parameters: Optional[list] = None,
+        secret_arn: Optional[str] = None,
+        statement_name: Optional[str] = None,
+        with_event: bool = False,
+        await_result: bool = True,
+        poll_interval: int = 10,
+        aws_conn_id: str = 'aws_default',
+        region: Optional[str] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.database = database
+        self.sql = sql
+        self.cluster_identifier = cluster_identifier
+        self.db_user = db_user
+        self.parameters = parameters
+        self.secret_arn = secret_arn
+        self.statement_name = statement_name
+        self.with_event = with_event
+        self.await_result = await_result
+        if poll_interval > 0:
+            self.poll_interval = poll_interval
+        else:
+            self.log.warning(
+                "Invalid poll_interval %s; falling back to the default of 10 seconds",
+                poll_interval,
+            )
+            self.poll_interval = 10
+        self.aws_conn_id = aws_conn_id
+        self.region = region
+        self.statement_id: Optional[str] = None
+
+    @cached_property
+    def hook(self) -> RedshiftDataHook:
+        """Create and return a RedshiftDataHook."""
+        return RedshiftDataHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
+
+    def execute_query(self) -> str:
+        resp = self.hook.conn.execute_statement(
+            ClusterIdentifier=self.cluster_identifier,
+            Database=self.database,
+            DbUser=self.db_user,
+            Sql=self.sql,
+            Parameters=self.parameters,
+            SecretArn=self.secret_arn,
+            StatementName=self.statement_name,
+            WithEvent=self.with_event,
+        )
+        return resp['Id']
+
+    def wait_for_results(self, statement_id):
+        while True:
+            self.log.info("Polling statement %s", statement_id)
+            resp = self.hook.conn.describe_statement(
+                Id=statement_id,
+            )
+            status = resp['Status']
+            if status == 'FINISHED':
+                return status
+            elif status in ('FAILED', 'ABORTED'):
+                raise ValueError(f"Statement {statement_id!r} terminated with status {status}.")
+            else:
+                self.log.info("Query status: %s", status)
+            sleep(self.poll_interval)
+
+    def execute(self, context: 'Context') -> str:
+        """Execute a statement against Amazon Redshift"""
+        self.log.info("Executing statement: %s", self.sql)
+
+        self.statement_id = self.execute_query()
+
+        if self.await_result:
+            self.wait_for_results(self.statement_id)
+
+        return self.statement_id
+
+    def on_kill(self) -> None:
+        """Cancel the submitted Redshift query"""
+        if self.statement_id:
+            self.log.info('Received a kill signal.')
+            self.log.info('Stopping query with statementId - %s', self.statement_id)
+
+            try:
+                self.hook.conn.cancel_statement(Id=self.statement_id)
+            except Exception as ex:
+                self.log.error('Unable to cancel query. Exiting. %s', ex)
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index b61afcf95ca5a..678e037c427b3 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -114,6 +114,12 @@ integrations:
       - /docs/apache-airflow-providers-amazon/operators/redshift_sql.rst
       - /docs/apache-airflow-providers-amazon/operators/redshift_cluster.rst
     tags: [aws]
+  - integration-name: Amazon Redshift Data
+    external-doc-url: https://aws.amazon.com/redshift/
+    logo: /integration-logos/aws/Amazon-Redshift_light-bg@4x.png
+    how-to-guide:
+      - /docs/apache-airflow-providers-amazon/operators/redshift_data.rst
+    tags: [aws]
   - integration-name: Amazon SageMaker
     external-doc-url: https://aws.amazon.com/sagemaker/
     logo: /integration-logos/aws/Amazon-SageMaker_light-bg@4x.png
@@ -271,6 +277,7 @@ operators:
       - airflow.providers.amazon.aws.operators.redshift
      - airflow.providers.amazon.aws.operators.redshift_sql
       - airflow.providers.amazon.aws.operators.redshift_cluster
+      - airflow.providers.amazon.aws.operators.redshift_data
 
 sensors:
   - integration-name: Amazon Athena
@@ -401,6 +408,7 @@ hooks:
       - airflow.providers.amazon.aws.hooks.redshift
       - airflow.providers.amazon.aws.hooks.redshift_sql
       - airflow.providers.amazon.aws.hooks.redshift_cluster
+      - airflow.providers.amazon.aws.hooks.redshift_data
   - integration-name: Amazon Simple Storage Service (S3)
     python-modules:
       - airflow.providers.amazon.aws.hooks.s3
@@ -505,6 +513,8 @@ connection-types:
     connection-type: emr
   - hook-class-name: airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook
     connection-type: redshift
+  - hook-class-name: airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook
+    connection-type: aws
 
 secrets-backends:
   - airflow.providers.amazon.aws.secrets.secrets_manager.SecretsManagerBackend
diff --git a/docs/apache-airflow-providers-amazon/index.rst b/docs/apache-airflow-providers-amazon/index.rst
index 93dac67ebe52b..f2665cc1f8691 100644
--- a/docs/apache-airflow-providers-amazon/index.rst
+++ b/docs/apache-airflow-providers-amazon/index.rst
@@ -77,18 +77,19 @@ You can install this package on top of an existing Airflow 2.1+ installation via
 PIP requirements
 ----------------
 
-======================= ===================
-PIP package             Version required
-======================= ===================
-``apache-airflow``      ``>=2.1.0``
-``boto3``               ``>=1.15.0,<2.0.0``
-``jsonpath_ng``         ``>=1.5.3``
-``pandas``              ``>=0.17.1, <1.4``
-``redshift_connector``  ``~=2.0.888``
-``sqlalchemy_redshift`` ``~=0.8.6``
-``watchtower``          ``~=2.0.1``
-``mypy-boto3-rds``      ``>=1.21.0``
-======================= ===================
+============================ ===================
+PIP package                  Version required
+============================ ===================
+``apache-airflow``           ``>=2.1.0``
+``boto3``                    ``>=1.15.0,<2.0.0``
+``jsonpath_ng``              ``>=1.5.3``
+``pandas``                   ``>=0.17.1, <1.4``
+``redshift_connector``       ``~=2.0.888``
+``sqlalchemy_redshift``      ``~=0.8.6``
+``watchtower``               ``~=2.0.1``
+``mypy-boto3-rds``           ``>=1.21.0``
+``mypy-boto3-redshift-data`` ``>=1.21.0``
+============================ ===================
 
 Cross provider package dependencies
 -----------------------------------
diff --git a/docs/apache-airflow-providers-amazon/operators/redshift_data.rst b/docs/apache-airflow-providers-amazon/operators/redshift_data.rst
new file mode 100644
index 0000000000000..73620089608ab
--- /dev/null
+++ 
b/docs/apache-airflow-providers-amazon/operators/redshift_data.rst
@@ -0,0 +1,52 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+.. _howto/operator:RedshiftDataOperator:
+
+RedshiftDataOperator
+====================
+
+.. contents::
+  :depth: 1
+  :local:
+
+Overview
+--------
+
+Use the :class:`RedshiftDataOperator <airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator>` to execute
+statements against an Amazon Redshift cluster.
+
+This differs from ``RedshiftSQLOperator`` in that it allows users to query and retrieve data via the AWS API, avoiding the need for a Postgres connection.
+
+example_redshift_data_execute_sql.py
+------------------------------------
+
+Purpose
+"""""""
+
+This is a basic example DAG for using :class:`RedshiftDataOperator <airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator>`
+to execute statements against an Amazon Redshift cluster.
+
+List tables in database
+"""""""""""""""""""""""
+
+In the following code we list the tables in the provided database.
+
+.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_redshift_data_execute_sql.py
+    :language: python
+    :start-after: [START howto_redshift_data]
+    :end-before: [END howto_redshift_data]
diff --git a/setup.py b/setup.py
index dcc7064557266..9a50d69297d89 100644
--- a/setup.py
+++ b/setup.py
@@ -203,6 +203,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
     'sqlalchemy_redshift>=0.8.6',
     pandas_requirement,
     'mypy-boto3-rds>=1.21.0',
+    'mypy-boto3-redshift-data>=1.21.0',
 ]
 apache_beam = [
     'apache-beam>=2.33.0',
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_data.py b/tests/providers/amazon/aws/hooks/test_redshift_data.py
new file mode 100644
index 0000000000000..55f345c9261b9
--- /dev/null
+++ b/tests/providers/amazon/aws/hooks/test_redshift_data.py
@@ -0,0 +1,26 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# +from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook + + +class TestRedshiftDataHook: + def test_conn_attribute(self): + hook = RedshiftDataHook(aws_conn_id='aws_default', region_name='us-east-1') + assert hasattr(hook, 'conn') + assert hook.conn.__class__.__name__ == 'RedshiftDataAPIService' diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py new file mode 100644 index 0000000000000..c6a694332f1d4 --- /dev/null +++ b/tests/providers/amazon/aws/operators/test_redshift_data.py @@ -0,0 +1,108 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest import mock + +from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator + +CONN_ID = "aws_conn_test" +TASK_ID = "task_id" +SQL = "sql" +DATABASE = "database" +STATEMENT_ID = "statement_id" + + +class TestRedshiftDataOperator: + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") + def test_execute_without_waiting(self, mock_conn): + mock_conn.execute_statement.return_value = {'Id': STATEMENT_ID} + operator = RedshiftDataOperator( + aws_conn_id=CONN_ID, + task_id=TASK_ID, + sql=SQL, + database=DATABASE, + await_result=False, + ) + operator.execute(None) + mock_conn.execute_statement.assert_called_once_with( + ClusterIdentifier=None, + Database=DATABASE, + DbUser=None, + Sql=SQL, + Parameters=None, + SecretArn=None, + StatementName=None, + WithEvent=False, + ) + mock_conn.describe_statement.assert_not_called() + + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") + def test_execute(self, mock_conn): + parameters = [{"name": "id", "value": "1"}] + mock_conn.execute_statement.return_value = {'Id': STATEMENT_ID} + mock_conn.describe_statement.return_value = {"Status": "FINISHED"} + operator = RedshiftDataOperator( + aws_conn_id=CONN_ID, + task_id=TASK_ID, + sql=SQL, + parameters=parameters, + database=DATABASE, + ) + operator.execute(None) + mock_conn.execute_statement.assert_called_once_with( + ClusterIdentifier=None, + Database=DATABASE, + DbUser=None, + Sql=SQL, + Parameters=parameters, + SecretArn=None, + StatementName=None, + WithEvent=False, + ) + mock_conn.describe_statement.assert_called_once_with( + Id=STATEMENT_ID, + ) + + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") + def test_on_kill_without_query(self, mock_conn): + mock_conn.execute_statement.return_value = {'Id': STATEMENT_ID} + operator = RedshiftDataOperator( + aws_conn_id=CONN_ID, + task_id=TASK_ID, + sql=SQL, + database=DATABASE, + await_result=False, + ) + operator.on_kill() + mock_conn.cancel_statement.assert_not_called() + + 
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") + def test_on_kill_with_query(self, mock_conn): + mock_conn.execute_statement.return_value = {'Id': STATEMENT_ID} + operator = RedshiftDataOperator( + aws_conn_id=CONN_ID, + task_id=TASK_ID, + sql=SQL, + database=DATABASE, + await_result=False, + ) + operator.execute(None) + operator.on_kill() + mock_conn.cancel_statement.assert_called_once_with( + Id=STATEMENT_ID, + )
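
Note on consuming results (not part of the patch itself): the example DAG above prints only the
first page returned by ``get_statement_result``. The Redshift Data API returns result rows in
pages and signals additional pages via a ``NextToken`` field, so a consumer that needs every row
has to keep calling ``get_statement_result`` until the token is absent. Below is a minimal sketch
of such a helper; the function name ``fetch_all_rows`` is illustrative only and is not introduced
by this change.

    from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook


    def fetch_all_rows(statement_id: str) -> list:
        """Collect every record of a finished statement, following pagination tokens."""
        hook = RedshiftDataHook()
        records = []
        token = None
        while True:
            kwargs = {"Id": statement_id}
            if token:
                # Pass the token from the previous page to fetch the next one
                kwargs["NextToken"] = token
            resp = hook.conn.get_statement_result(**kwargs)
            records.extend(resp.get("Records", []))
            token = resp.get("NextToken")
            if not token:
                return records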