Skip to content

Commit

Permalink
fix: RedshiftDataHook and RdsHook not use cached connection (#24387)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis authored Jun 19, 2022
1 parent d4eeede commit 796e0a0
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 34 deletions.
29 changes: 18 additions & 11 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@
import configparser
import datetime
import logging
import sys
import warnings
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, Generic, Optional, Tuple, Type, TypeVar, Union

import boto3
import botocore
Expand All @@ -40,21 +39,18 @@
from botocore.client import ClientMeta
from botocore.config import Config
from botocore.credentials import ReadOnlyCredentials
from slugify import slugify

if sys.version_info >= (3, 8):
from functools import cached_property
else:
from cached_property import cached_property

from dateutil.tz import tzlocal
from slugify import slugify

from airflow.compat.functools import cached_property
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models.connection import Connection
from airflow.utils.log.logging_mixin import LoggingMixin

BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[boto3.client, boto3.resource])


class BaseSessionFactory(LoggingMixin):
"""
Expand Down Expand Up @@ -372,7 +368,7 @@ def web_identity_token_loader():
return web_identity_token_loader


class AwsBaseHook(BaseHook):
class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
"""
Interact with AWS.
This class is a thin wrapper around the boto3 python library.
Expand Down Expand Up @@ -537,7 +533,7 @@ def conn_region_name(self) -> str:
def conn_partition(self) -> str:
return self.conn_client_meta.partition

def get_conn(self) -> Union[boto3.client, boto3.resource]:
def get_conn(self) -> BaseAwsConnection:
"""
Get the underlying boto3 client/resource (cached)
Expand Down Expand Up @@ -616,6 +612,17 @@ def decorator_f(self, *args, **kwargs):
return retry_decorator


class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]):
"""
Interact with AWS.
This class is a thin wrapper around the boto3 python library
with basic conn annotation.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook`
"""


def _parse_s3_config(
config_file_name: str, config_format: Optional[str] = "boto", profile: Optional[str] = None
) -> Tuple[Optional[str], Optional[str]]:
Expand Down
18 changes: 4 additions & 14 deletions airflow/providers/amazon/aws/hooks/rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@

from typing import TYPE_CHECKING

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook

if TYPE_CHECKING:
from mypy_boto3_rds import RDSClient
from mypy_boto3_rds import RDSClient # noqa


class RdsHook(AwsBaseHook):
class RdsHook(AwsGenericHook['RDSClient']):
"""
Interact with AWS RDS using proper client from the boto3 library.
Expand All @@ -39,21 +39,11 @@ class RdsHook(AwsBaseHook):
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook`
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""

def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = "rds"
super().__init__(*args, **kwargs)

@property
def conn(self) -> 'RDSClient':
"""
Get the underlying boto3 RDS client (cached)
:return: boto3 RDS client
:rtype: botocore.client.RDS
"""
return super().conn
13 changes: 4 additions & 9 deletions airflow/providers/amazon/aws/hooks/redshift_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@

from typing import TYPE_CHECKING

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook

if TYPE_CHECKING:
from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient
from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient # noqa


class RedshiftDataHook(AwsBaseHook):
class RedshiftDataHook(AwsGenericHook['RedshiftDataAPIServiceClient']):
"""
Interact with AWS Redshift Data, using the boto3 library
Hook attribute `conn` has all methods that listed in documentation
Expand All @@ -37,16 +37,11 @@ class RedshiftDataHook(AwsBaseHook):
are passed down to the underlying AwsBaseHook.
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook`
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""

def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = "redshift-data"
super().__init__(*args, **kwargs)

@property
def conn(self) -> 'RedshiftDataAPIServiceClient':
"""Get the underlying boto3 RedshiftDataAPIService client (cached)"""
return super().conn
3 changes: 3 additions & 0 deletions tests/providers/amazon/aws/hooks/test_rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ def test_conn_attribute(self):
hook = RdsHook(aws_conn_id='aws_default', region_name='us-east-1')
assert hasattr(hook, 'conn')
assert hook.conn.__class__.__name__ == 'RDS'
conn = hook.conn
assert conn is hook.conn # Cached property
assert conn is hook.get_conn() # Same object as returned by `conn` property
3 changes: 3 additions & 0 deletions tests/providers/amazon/aws/hooks/test_redshift_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,6 @@ def test_conn_attribute(self):
hook = RedshiftDataHook(aws_conn_id='aws_default', region_name='us-east-1')
assert hasattr(hook, 'conn')
assert hook.conn.__class__.__name__ == 'RedshiftDataAPIService'
conn = hook.conn
assert conn is hook.conn # Cached property
assert conn is hook.get_conn() # Same object as returned by `conn` property

0 comments on commit 796e0a0

Please sign in to comment.