From 845854201480b814c19c72895bed577ffa952432 Mon Sep 17 00:00:00 2001 From: Andrey Anshin Date: Sat, 11 Jun 2022 03:54:48 +0400 Subject: [PATCH] fix: RedshiftDataHook and RdsHook not use cached connection --- .../providers/amazon/aws/hooks/base_aws.py | 29 ++++++++++++------- airflow/providers/amazon/aws/hooks/rds.py | 18 +++--------- .../amazon/aws/hooks/redshift_data.py | 13 +++------ tests/providers/amazon/aws/hooks/test_rds.py | 3 ++ .../amazon/aws/hooks/test_redshift_data.py | 3 ++ 5 files changed, 32 insertions(+), 34 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py index 162be9ce47506..5e518adde04ac 100644 --- a/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/airflow/providers/amazon/aws/hooks/base_aws.py @@ -27,10 +27,9 @@ import configparser import datetime import logging -import sys import warnings from functools import wraps -from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Generic, Optional, Tuple, Type, TypeVar, Union import boto3 import botocore @@ -40,21 +39,18 @@ from botocore.client import ClientMeta from botocore.config import Config from botocore.credentials import ReadOnlyCredentials -from slugify import slugify - -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from cached_property import cached_property - from dateutil.tz import tzlocal +from slugify import slugify +from airflow.compat.functools import cached_property from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook from airflow.models.connection import Connection from airflow.utils.log.logging_mixin import LoggingMixin +BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[boto3.client, boto3.resource]) + class BaseSessionFactory(LoggingMixin): """ @@ -372,7 +368,7 @@ def web_identity_token_loader(): return web_identity_token_loader -class AwsBaseHook(BaseHook): +class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]): """ Interact with AWS. This class is a thin wrapper around the boto3 python library. @@ -537,7 +533,7 @@ def conn_region_name(self) -> str: def conn_partition(self) -> str: return self.conn_client_meta.partition - def get_conn(self) -> Union[boto3.client, boto3.resource]: + def get_conn(self) -> BaseAwsConnection: """ Get the underlying boto3 client/resource (cached) @@ -616,6 +612,17 @@ def decorator_f(self, *args, **kwargs): return retry_decorator +class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]): + """ + Interact with AWS. + This class is a thin wrapper around the boto3 python library + with basic conn annotation. + + .. seealso:: + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook` + """ + + def _parse_s3_config( config_file_name: str, config_format: Optional[str] = "boto", profile: Optional[str] = None ) -> Tuple[Optional[str], Optional[str]]: diff --git a/airflow/providers/amazon/aws/hooks/rds.py b/airflow/providers/amazon/aws/hooks/rds.py index 3539e951cade6..45fc99bd578ca 100644 --- a/airflow/providers/amazon/aws/hooks/rds.py +++ b/airflow/providers/amazon/aws/hooks/rds.py @@ -19,13 +19,13 @@ from typing import TYPE_CHECKING -from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook if TYPE_CHECKING: - from mypy_boto3_rds import RDSClient + from mypy_boto3_rds import RDSClient # noqa -class RdsHook(AwsBaseHook): +class RdsHook(AwsGenericHook['RDSClient']): """ Interact with AWS RDS using proper client from the boto3 library. @@ -39,7 +39,7 @@ class RdsHook(AwsBaseHook): are passed down to the underlying AwsBaseHook. .. seealso:: - :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook` :param aws_conn_id: The Airflow connection used for AWS credentials. """ @@ -47,13 +47,3 @@ class RdsHook(AwsBaseHook): def __init__(self, *args, **kwargs) -> None: kwargs["client_type"] = "rds" super().__init__(*args, **kwargs) - - @property - def conn(self) -> 'RDSClient': - """ - Get the underlying boto3 RDS client (cached) - - :return: boto3 RDS client - :rtype: botocore.client.RDS - """ - return super().conn diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py index 74459f58a5f8a..e9a154368a928 100644 --- a/airflow/providers/amazon/aws/hooks/redshift_data.py +++ b/airflow/providers/amazon/aws/hooks/redshift_data.py @@ -18,13 +18,13 @@ from typing import TYPE_CHECKING -from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook if TYPE_CHECKING: - from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient + from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient # noqa -class RedshiftDataHook(AwsBaseHook): +class RedshiftDataHook(AwsGenericHook['RedshiftDataAPIServiceClient']): """ Interact with AWS Redshift Data, using the boto3 library Hook attribute `conn` has all methods that listed in documentation @@ -37,7 +37,7 @@ class RedshiftDataHook(AwsBaseHook): are passed down to the underlying AwsBaseHook. .. seealso:: - :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook` :param aws_conn_id: The Airflow connection used for AWS credentials. """ @@ -45,8 +45,3 @@ class RedshiftDataHook(AwsBaseHook): def __init__(self, *args, **kwargs) -> None: kwargs["client_type"] = "redshift-data" super().__init__(*args, **kwargs) - - @property - def conn(self) -> 'RedshiftDataAPIServiceClient': - """Get the underlying boto3 RedshiftDataAPIService client (cached)""" - return super().conn diff --git a/tests/providers/amazon/aws/hooks/test_rds.py b/tests/providers/amazon/aws/hooks/test_rds.py index 89ec78246dcb8..ae0a1facec35e 100644 --- a/tests/providers/amazon/aws/hooks/test_rds.py +++ b/tests/providers/amazon/aws/hooks/test_rds.py @@ -25,3 +25,6 @@ def test_conn_attribute(self): hook = RdsHook(aws_conn_id='aws_default', region_name='us-east-1') assert hasattr(hook, 'conn') assert hook.conn.__class__.__name__ == 'RDS' + conn = hook.conn + assert conn is hook.conn # Cached property + assert conn is hook.get_conn() # Same object as returned by `conn` property diff --git a/tests/providers/amazon/aws/hooks/test_redshift_data.py b/tests/providers/amazon/aws/hooks/test_redshift_data.py index 55f345c9261b9..829fa0794fc34 100644 --- a/tests/providers/amazon/aws/hooks/test_redshift_data.py +++ b/tests/providers/amazon/aws/hooks/test_redshift_data.py @@ -24,3 +24,6 @@ def test_conn_attribute(self): hook = RedshiftDataHook(aws_conn_id='aws_default', region_name='us-east-1') assert hasattr(hook, 'conn') assert hook.conn.__class__.__name__ == 'RedshiftDataAPIService' + conn = hook.conn + assert conn is hook.conn # Cached property + assert conn is hook.get_conn() # Same object as returned by `conn` property