Skip to content

Commit

Permalink
Add AWS Redshift Serverless support to PostgresHook (apache#43669)
Browse files Browse the repository at this point in the history
* add AWS Redshift Serverless support to PostgresHook

* exception types and doc strings
  • Loading branch information
topherinternational authored and ellisms committed Nov 13, 2024
1 parent 76fb8e2 commit 34e2ee5
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 6 deletions.
36 changes: 30 additions & 6 deletions providers/src/airflow/providers/postgres/hooks/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,17 @@ class PostgresHook(DbApiHook):
"aws_default" connection to get the temporary token unless you override
in extras.
extras example: ``{"iam":true, "aws_conn_id":"my_aws_conn"}``
For Redshift, also use redshift in the extra connection parameters and
set it to true. The cluster-identifier is extracted from the beginning of
the host field, so is optional. It can however be overridden in the extra field.
extras example: ``{"iam":true, "redshift":true, "cluster-identifier": "my_cluster_id"}``
For Redshift Serverless, use redshift-serverless in the extra connection parameters and
set it to true. The workgroup-name is extracted from the beginning of
the host field, so is optional. It can however be overridden in the extra field.
extras example: ``{"iam":true, "redshift-serverless":true, "workgroup-name": "my_serverless_workgroup"}``
:param postgres_conn_id: The :ref:`postgres conn id <howto/connection:postgres>`
reference to a specific postgres database.
:param options: Optional. Specifies command-line options to send to the server
Expand Down Expand Up @@ -172,8 +178,10 @@ def get_conn(self) -> connection:
if arg_name not in [
"iam",
"redshift",
"redshift-serverless",
"cursor",
"cluster-identifier",
"workgroup-name",
"aws_conn_id",
]:
conn_args[arg_name] = arg_val
Expand Down Expand Up @@ -247,9 +255,9 @@ def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:
try:
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
except ImportError:
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowOptionalProviderFeatureException

raise AirflowException(
raise AirflowOptionalProviderFeatureException(
"apache-airflow-providers-amazon not installed, run: "
"pip install 'apache-airflow-providers-postgres[amazon]'."
)
Expand All @@ -262,7 +270,7 @@ def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:
# ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns my-cluster
cluster_identifier = conn.extra_dejson.get("cluster-identifier", conn.host.split(".")[0])
redshift_client = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="redshift").conn
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift/client/get_cluster_credentials.html#Redshift.Client.get_cluster_credentials
cluster_creds = redshift_client.get_cluster_credentials(
DbUser=login,
DbName=self.database or conn.schema,
Expand All @@ -271,10 +279,26 @@ def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:
)
token = cluster_creds["DbPassword"]
login = cluster_creds["DbUser"]
elif conn.extra_dejson.get("redshift-serverless", False):
port = conn.port or 5439
# Pull the workgroup-name from the query params/extras, if not there then pull it from the
# beginning of the Redshift URL
# ex. workgroup-name.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns workgroup-name
workgroup_name = conn.extra_dejson.get("workgroup-name", conn.host.split(".")[0])
redshift_serverless_client = AwsBaseHook(
aws_conn_id=aws_conn_id, client_type="redshift-serverless"
).conn
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift-serverless/client/get_credentials.html#RedshiftServerless.Client.get_credentials
cluster_creds = redshift_serverless_client.get_credentials(
dbName=self.database or conn.schema,
workgroupName=workgroup_name,
)
token = cluster_creds["DbPassword"]
login = cluster_creds["DbUser"]
else:
port = conn.port or 5432
rds_client = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="rds").conn
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.generate_db_auth_token
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds/client/generate_db_auth_token.html#RDS.Client.generate_db_auth_token
token = rds_client.generate_db_auth_token(conn.host, port, conn.login)
return login, token, port

Expand Down Expand Up @@ -371,9 +395,9 @@ def _get_openlineage_redshift_authority_part(self, connection) -> str:
try:
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
except ImportError:
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowOptionalProviderFeatureException

raise AirflowException(
raise AirflowOptionalProviderFeatureException(
"apache-airflow-providers-amazon not installed, run: "
"pip install 'apache-airflow-providers-postgres[amazon]'."
)
Expand Down
74 changes: 74 additions & 0 deletions providers/tests/postgres/hooks/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,80 @@ def test_get_conn_rds_iam_redshift(
port=(port or 5439),
)

@mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook")
@pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
@pytest.mark.parametrize("port", [5432, 5439, None])
@pytest.mark.parametrize(
    "host,conn_workgroup_name,expected_workgroup_name",
    [
        # No "workgroup-name" in extras: the name is derived from the first
        # dot-separated segment of the host.
        (
            "serverless-workgroup.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
            NOTSET,
            "serverless-workgroup",
        ),
        # "workgroup-name" set in extras: it overrides the host-derived value.
        (
            "cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com",
            "different-workgroup",
            "different-workgroup",
        ),
    ],
)
def test_get_conn_rds_iam_redshift_serverless(
    self,
    mock_aws_hook_class,
    mock_connect,
    aws_conn_id,
    port,
    host,
    conn_workgroup_name,
    expected_workgroup_name,
):
    """IAM auth against Redshift Serverless.

    Verifies that when the connection extras set ``iam`` and
    ``redshift-serverless``, the hook obtains credentials via the
    boto3 ``redshift-serverless`` client's ``get_credentials`` call
    (workgroup name from extras, falling back to the host prefix) and
    passes them through to ``psycopg2.connect``.
    """
    mock_conn_extra = {
        "iam": True,
        "redshift-serverless": True,
    }
    if aws_conn_id is not NOTSET:
        mock_conn_extra["aws_conn_id"] = aws_conn_id
    if conn_workgroup_name is not NOTSET:
        mock_conn_extra["workgroup-name"] = conn_workgroup_name

    self.connection.extra = json.dumps(mock_conn_extra)
    self.connection.host = host
    self.connection.port = port
    mock_db_user = f"IAM:{self.connection.login}"
    mock_db_pass = "aws_token"

    # Mock AWS Connection
    mock_aws_hook_instance = mock_aws_hook_class.return_value
    mock_client = mock.MagicMock()
    mock_client.get_credentials.return_value = {
        "DbPassword": mock_db_pass,
        "DbUser": mock_db_user,
    }
    type(mock_aws_hook_instance).conn = mock.PropertyMock(return_value=mock_client)

    self.db_hook.get_conn()
    # Check AwsBaseHook initialization
    mock_aws_hook_class.assert_called_once_with(
        # If aws_conn_id is not set then fall back to aws_default
        aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else "aws_default",
        client_type="redshift-serverless",
    )
    # Check boto3 'redshift-serverless' client method `get_credentials` call args
    mock_client.get_credentials.assert_called_once_with(
        dbName=self.connection.schema,
        workgroupName=expected_workgroup_name,
    )
    # Check expected psycopg2 connection call args
    # (port falls back to the Redshift default, 5439)
    mock_connect.assert_called_once_with(
        user=mock_db_user,
        password=mock_db_pass,
        host=host,
        dbname=self.connection.schema,
        port=(port or 5439),
    )

def test_get_uri_from_connection_without_database_override(self):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
Expand Down

0 comments on commit 34e2ee5

Please sign in to comment.