Skip to content

Commit

Permalink
exception types and doc strings
Browse files Browse the repository at this point in the history
  • Loading branch information
topherinternational committed Nov 5, 2024
1 parent fa10829 commit e2bce9a
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions providers/src/airflow/providers/postgres/hooks/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,17 @@ class PostgresHook(DbApiHook):
"aws_default" connection to get the temporary token unless you override
in extras.
extras example: ``{"iam":true, "aws_conn_id":"my_aws_conn"}``
For Redshift, also use redshift in the extra connection parameters and
set it to true. The cluster-identifier is extracted from the beginning of
the host field, so is optional. It can however be overridden in the extra field.
extras example: ``{"iam":true, "redshift":true, "cluster-identifier": "my_cluster_id"}``
For Redshift Serverless, use redshift-serverless in the extra connection parameters and
set it to true. The workgroup-name is extracted from the beginning of
the host field, so is optional. It can however be overridden in the extra field.
extras example: ``{"iam":true, "redshift-serverless":true, "workgroup-name": "my_serverless_workgroup"}``
:param postgres_conn_id: The :ref:`postgres conn id <howto/connection:postgres>`
reference to a specific postgres database.
:param options: Optional. Specifies command-line options to send to the server
Expand Down Expand Up @@ -249,9 +255,9 @@ def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:
try:
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
except ImportError:
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowOptionalProviderFeatureException

raise AirflowException(
raise AirflowOptionalProviderFeatureException(
"apache-airflow-providers-amazon not installed, run: "
"pip install 'apache-airflow-providers-postgres[amazon]'."
)
Expand All @@ -264,7 +270,7 @@ def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:
# ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns my-cluster
cluster_identifier = conn.extra_dejson.get("cluster-identifier", conn.host.split(".")[0])
redshift_client = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="redshift").conn
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift/client/get_cluster_credentials.html#Redshift.Client.get_cluster_credentials
cluster_creds = redshift_client.get_cluster_credentials(
DbUser=login,
DbName=self.database or conn.schema,
Expand All @@ -277,12 +283,12 @@ def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:
port = conn.port or 5439
# Pull the workgroup-name from the query params/extras, if not there then pull it from the
# beginning of the Redshift URL
# ex. workgroup-name.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns my-cluster
# ex. workgroup-name.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns workgroup-name
workgroup_name = conn.extra_dejson.get("workgroup-name", conn.host.split(".")[0])
redshift_serverless_client = AwsBaseHook(
aws_conn_id=aws_conn_id, client_type="redshift-serverless"
).conn
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift-serverless/client/get_credentials.html#RedshiftServerless.Client.get_credentials
cluster_creds = redshift_serverless_client.get_credentials(
dbName=self.database or conn.schema,
workgroupName=workgroup_name,
Expand All @@ -292,7 +298,7 @@ def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:
else:
port = conn.port or 5432
rds_client = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="rds").conn
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.generate_db_auth_token
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds/client/generate_db_auth_token.html#RDS.Client.generate_db_auth_token
token = rds_client.generate_db_auth_token(conn.host, port, conn.login)
return login, token, port

Expand Down Expand Up @@ -389,9 +395,9 @@ def _get_openlineage_redshift_authority_part(self, connection) -> str:
try:
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
except ImportError:
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowOptionalProviderFeatureException

raise AirflowException(
raise AirflowOptionalProviderFeatureException(
"apache-airflow-providers-amazon not installed, run: "
"pip install 'apache-airflow-providers-postgres[amazon]'."
)
Expand Down

0 comments on commit e2bce9a

Please sign in to comment.