diff --git a/airflow/providers/amazon/aws/operators/redshift.py b/airflow/providers/amazon/aws/operators/redshift.py index 9c1f8adfbf080..52d82b40fad9a 100644 --- a/airflow/providers/amazon/aws/operators/redshift.py +++ b/airflow/providers/amazon/aws/operators/redshift.py @@ -18,7 +18,7 @@ from typing import Dict, Iterable, Optional, Union from airflow.models import BaseOperator -from airflow.providers.amazon.aws.hooks.redshift import RedshiftSQLHook +from airflow.providers.amazon.aws.hooks.redshift import RedshiftHook, RedshiftSQLHook class RedshiftSQLOperator(BaseOperator): @@ -71,3 +71,85 @@ def execute(self, context: dict) -> None: self.log.info(f"Executing statement: {self.sql}") hook = self.get_hook() hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) + + +class RedshiftResumeClusterOperator(BaseOperator): + """ + Resume a paused AWS Redshift Cluster + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:RedshiftResumeClusterOperator` + + :param cluster_identifier: id of the AWS Redshift Cluster + :type cluster_identifier: str + :param aws_conn_id: aws connection to use + :type aws_conn_id: str + """ + + template_fields = ("cluster_identifier",) + ui_color = "#eeaa11" + ui_fgcolor = "#ffffff" + + def __init__( + self, + *, + cluster_identifier: str, + aws_conn_id: str = "aws_default", + **kwargs, + ): + super().__init__(**kwargs) + self.cluster_identifier = cluster_identifier + self.aws_conn_id = aws_conn_id + + def execute(self, context): + redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id) + cluster_state = redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier) + if cluster_state == 'paused': + self.log.info("Starting Redshift cluster %s", self.cluster_identifier) + redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier) + else: + self.log.warning( + "Unable to resume cluster since cluster is currently in status: %s", cluster_state + ) + + +class RedshiftPauseClusterOperator(BaseOperator): + """ + Pause an AWS Redshift Cluster if it has status `available`. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:RedshiftPauseClusterOperator` + + :param cluster_identifier: id of the AWS Redshift Cluster + :type cluster_identifier: str + :param aws_conn_id: aws connection to use + :type aws_conn_id: str + """ + + template_fields = ("cluster_identifier",) + ui_color = "#eeaa11" + ui_fgcolor = "#ffffff" + + def __init__( + self, + *, + cluster_identifier: str, + aws_conn_id: str = "aws_default", + **kwargs, + ): + super().__init__(**kwargs) + self.cluster_identifier = cluster_identifier + self.aws_conn_id = aws_conn_id + + def execute(self, context): + redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id) + cluster_state = redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier) + if cluster_state == 'available': + self.log.info("Pausing Redshift cluster %s", self.cluster_identifier) + redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier) + else: + self.log.warning( + "Unable to pause cluster since cluster is currently in status: %s", cluster_state + ) diff --git a/docs/apache-airflow-providers-amazon/operators/redshift.rst b/docs/apache-airflow-providers-amazon/operators/redshift.rst index 58ee26c3a9c5f..5a5d1d799fe4b 100644 --- a/docs/apache-airflow-providers-amazon/operators/redshift.rst +++ b/docs/apache-airflow-providers-amazon/operators/redshift.rst @@ -94,3 +94,25 @@ All together, here is our DAG: :language: python :start-after: [START redshift_operator_howto_guide] :end-before: [END redshift_operator_howto_guide] + + +.. _howto/operator:RedshiftResumeClusterOperator: + +Resume a Redshift Cluster +""""""""""""""""""""""""""""""""""""""""""" + +To resume a 'paused' AWS Redshift Cluster you can use +:class:`RedshiftResumeClusterOperator ` + +This Operator leverages the AWS CLI +`resume-cluster `__ API + +.. _howto/operator:RedshiftPauseClusterOperator: + +Pause a Redshift Cluster +""""""""""""""""""""""""""""""""""""""""""" + +To pause an 'available' AWS Redshift Cluster you can use +:class:`RedshiftPauseClusterOperator ` +This Operator leverages the AWS CLI +`pause-cluster `__ API diff --git a/tests/providers/amazon/aws/operators/test_redshift.py b/tests/providers/amazon/aws/operators/test_redshift.py index de9206ef83ff3..d43cf4a939b73 100644 --- a/tests/providers/amazon/aws/operators/test_redshift.py +++ b/tests/providers/amazon/aws/operators/test_redshift.py @@ -22,7 +22,11 @@ from parameterized import parameterized -from airflow.providers.amazon.aws.operators.redshift import RedshiftSQLOperator +from airflow.providers.amazon.aws.operators.redshift import ( + RedshiftPauseClusterOperator, + RedshiftResumeClusterOperator, + RedshiftSQLOperator, +) class TestRedshiftSQLOperator(unittest.TestCase): @@ -42,3 +46,63 @@ def test_redshift_operator(self, test_autocommit, test_parameters, mock_get_hook autocommit=test_autocommit, parameters=test_parameters, ) + + +class TestResumeClusterOperator: + def test_init(self): + redshift_operator = RedshiftResumeClusterOperator( + task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test" + ) + assert redshift_operator.task_id == "task_test" + assert redshift_operator.cluster_identifier == "test_cluster" + assert redshift_operator.aws_conn_id == "aws_conn_test" + + @mock.patch("airflow.providers.amazon.aws.hooks.redshift.RedshiftHook.cluster_status") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift.RedshiftHook.get_conn") + def test_resume_cluster_is_called_when_cluster_is_paused(self, mock_get_conn, mock_cluster_status): + mock_cluster_status.return_value = 'paused' + redshift_operator = RedshiftResumeClusterOperator( + task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test" + ) + redshift_operator.execute(None) + mock_get_conn.return_value.resume_cluster.assert_called_once_with(ClusterIdentifier='test_cluster') + + @mock.patch("airflow.providers.amazon.aws.hooks.redshift.RedshiftHook.cluster_status") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift.RedshiftHook.get_conn") + def test_resume_cluster_not_called_when_cluster_is_not_paused(self, mock_get_conn, mock_cluster_status): + mock_cluster_status.return_value = 'available' + redshift_operator = RedshiftResumeClusterOperator( + task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test" + ) + redshift_operator.execute(None) + mock_get_conn.return_value.resume_cluster.assert_not_called() + + +class TestPauseClusterOperator: + def test_init(self): + redshift_operator = RedshiftPauseClusterOperator( + task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test" + ) + assert redshift_operator.task_id == "task_test" + assert redshift_operator.cluster_identifier == "test_cluster" + assert redshift_operator.aws_conn_id == "aws_conn_test" + + @mock.patch("airflow.providers.amazon.aws.hooks.redshift.RedshiftHook.cluster_status") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift.RedshiftHook.get_conn") + def test_pause_cluster_is_called_when_cluster_is_available(self, mock_get_conn, mock_cluster_status): + mock_cluster_status.return_value = 'available' + redshift_operator = RedshiftPauseClusterOperator( + task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test" + ) + redshift_operator.execute(None) + mock_get_conn.return_value.pause_cluster.assert_called_once_with(ClusterIdentifier='test_cluster') + + @mock.patch("airflow.providers.amazon.aws.hooks.redshift.RedshiftHook.cluster_status") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift.RedshiftHook.get_conn") + def test_pause_cluster_not_called_when_cluster_is_not_available(self, mock_get_conn, mock_cluster_status): + mock_cluster_status.return_value = 'paused' + redshift_operator = RedshiftPauseClusterOperator( + task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test" + ) + redshift_operator.execute(None) + mock_get_conn.return_value.pause_cluster.assert_not_called()