From 34e2ee55e7a5b0d330a8e384dbc061f971f20333 Mon Sep 17 00:00:00 2001 From: Topher Anderson <48180628+topherinternational@users.noreply.github.com> Date: Wed, 6 Nov 2024 08:22:23 +0100 Subject: [PATCH] Add AWS Redshift Serverless support to PostgresHook (#43669) * add AWS Redshift Serverless support to PostgresHook * exception types and doc strings --- .../providers/postgres/hooks/postgres.py | 36 +++++++-- .../tests/postgres/hooks/test_postgres.py | 74 +++++++++++++++++++ 2 files changed, 104 insertions(+), 6 deletions(-) diff --git a/providers/src/airflow/providers/postgres/hooks/postgres.py b/providers/src/airflow/providers/postgres/hooks/postgres.py index e9888fa1e7ba..725602bf57e4 100644 --- a/providers/src/airflow/providers/postgres/hooks/postgres.py +++ b/providers/src/airflow/providers/postgres/hooks/postgres.py @@ -59,11 +59,17 @@ class PostgresHook(DbApiHook): "aws_default" connection to get the temporary token unless you override in extras. extras example: ``{"iam":true, "aws_conn_id":"my_aws_conn"}`` + For Redshift, also use redshift in the extra connection parameters and set it to true. The cluster-identifier is extracted from the beginning of the host field, so is optional. It can however be overridden in the extra field. extras example: ``{"iam":true, "redshift":true, "cluster-identifier": "my_cluster_id"}`` + For Redshift Serverless, use redshift-serverless in the extra connection parameters and + set it to true. The workgroup-name is extracted from the beginning of + the host field, so is optional. It can however be overridden in the extra field. + extras example: ``{"iam":true, "redshift-serverless":true, "workgroup-name": "my_serverless_workgroup"}`` + :param postgres_conn_id: The :ref:`postgres conn id ` reference to a specific postgres database. :param options: Optional. 
Specifies command-line options to send to the server @@ -172,8 +178,10 @@ def get_conn(self) -> connection: if arg_name not in [ "iam", "redshift", + "redshift-serverless", "cursor", "cluster-identifier", + "workgroup-name", "aws_conn_id", ]: conn_args[arg_name] = arg_val @@ -247,9 +255,9 @@ def get_iam_token(self, conn: Connection) -> tuple[str, str, int]: try: from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook except ImportError: - from airflow.exceptions import AirflowException + from airflow.exceptions import AirflowOptionalProviderFeatureException - raise AirflowException( + raise AirflowOptionalProviderFeatureException( "apache-airflow-providers-amazon not installed, run: " "pip install 'apache-airflow-providers-postgres[amazon]'." ) @@ -262,7 +270,7 @@ def get_iam_token(self, conn: Connection) -> tuple[str, str, int]: # ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns my-cluster cluster_identifier = conn.extra_dejson.get("cluster-identifier", conn.host.split(".")[0]) redshift_client = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="redshift").conn - # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift/client/get_cluster_credentials.html#Redshift.Client.get_cluster_credentials cluster_creds = redshift_client.get_cluster_credentials( DbUser=login, DbName=self.database or conn.schema, @@ -271,10 +279,26 @@ def get_iam_token(self, conn: Connection) -> tuple[str, str, int]: ) token = cluster_creds["DbPassword"] login = cluster_creds["DbUser"] + elif conn.extra_dejson.get("redshift-serverless", False): + port = conn.port or 5439 + # Pull the workgroup-name from the query params/extras, if not there then pull it from the + # beginning of the Redshift URL + # ex. 
workgroup-name.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns workgroup-name + workgroup_name = conn.extra_dejson.get("workgroup-name", conn.host.split(".")[0]) + redshift_serverless_client = AwsBaseHook( + aws_conn_id=aws_conn_id, client_type="redshift-serverless" + ).conn + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift-serverless/client/get_credentials.html#RedshiftServerless.Client.get_credentials + cluster_creds = redshift_serverless_client.get_credentials( + dbName=self.database or conn.schema, + workgroupName=workgroup_name, + ) + token = cluster_creds["DbPassword"] + login = cluster_creds["DbUser"] else: port = conn.port or 5432 rds_client = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="rds").conn - # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.generate_db_auth_token + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds/client/generate_db_auth_token.html#RDS.Client.generate_db_auth_token token = rds_client.generate_db_auth_token(conn.host, port, conn.login) return login, token, port @@ -371,9 +395,9 @@ def _get_openlineage_redshift_authority_part(self, connection) -> str: try: from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook except ImportError: - from airflow.exceptions import AirflowException + from airflow.exceptions import AirflowOptionalProviderFeatureException - raise AirflowException( + raise AirflowOptionalProviderFeatureException( "apache-airflow-providers-amazon not installed, run: " "pip install 'apache-airflow-providers-postgres[amazon]'." 
) diff --git a/providers/tests/postgres/hooks/test_postgres.py b/providers/tests/postgres/hooks/test_postgres.py index 740f957643f8..bd8bf8d320aa 100644 --- a/providers/tests/postgres/hooks/test_postgres.py +++ b/providers/tests/postgres/hooks/test_postgres.py @@ -239,6 +239,80 @@ def test_get_conn_rds_iam_redshift( port=(port or 5439), ) + @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") + @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook") + @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"]) + @pytest.mark.parametrize("port", [5432, 5439, None]) + @pytest.mark.parametrize( + "host,conn_workgroup_name,expected_workgroup_name", + [ + ( + "serverless-workgroup.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com", + NOTSET, + "serverless-workgroup", + ), + ( + "cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com", + "different-workgroup", + "different-workgroup", + ), + ], + ) + def test_get_conn_rds_iam_redshift_serverless( + self, + mock_aws_hook_class, + mock_connect, + aws_conn_id, + port, + host, + conn_workgroup_name, + expected_workgroup_name, + ): + mock_conn_extra = { + "iam": True, + "redshift-serverless": True, + } + if aws_conn_id is not NOTSET: + mock_conn_extra["aws_conn_id"] = aws_conn_id + if conn_workgroup_name is not NOTSET: + mock_conn_extra["workgroup-name"] = conn_workgroup_name + + self.connection.extra = json.dumps(mock_conn_extra) + self.connection.host = host + self.connection.port = port + mock_db_user = f"IAM:{self.connection.login}" + mock_db_pass = "aws_token" + + # Mock AWS Connection + mock_aws_hook_instance = mock_aws_hook_class.return_value + mock_client = mock.MagicMock() + mock_client.get_credentials.return_value = { + "DbPassword": mock_db_pass, + "DbUser": mock_db_user, + } + type(mock_aws_hook_instance).conn = mock.PropertyMock(return_value=mock_client) + + self.db_hook.get_conn() + # Check AwsHook initialization 
+ mock_aws_hook_class.assert_called_once_with( + # If aws_conn_id not set then fallback to aws_default + aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else "aws_default", + client_type="redshift-serverless", + ) + # Check boto3 'redshift-serverless' client method `get_credentials` call args + mock_client.get_credentials.assert_called_once_with( + dbName=self.connection.schema, + workgroupName=expected_workgroup_name, + ) + # Check expected psycopg2 connection call args + mock_connect.assert_called_once_with( + user=mock_db_user, + password=mock_db_pass, + host=host, + dbname=self.connection.schema, + port=(port or 5439), + ) + def test_get_uri_from_connection_without_database_override(self): self.db_hook.get_connection = mock.MagicMock( return_value=Connection(