From 7e080456ad5efb5b945cb0b36c0adc505ed8fb16 Mon Sep 17 00:00:00 2001 From: Andrey Anshin Date: Sat, 18 Nov 2023 19:23:07 +0400 Subject: [PATCH 1/2] Remove backcompat with Airflow 2.3/2.4 in providers --- airflow/providers/common/sql/hooks/sql.py | 18 +----------------- airflow/providers/common/sql/hooks/sql.pyi | 6 +++--- airflow/providers/google/cloud/hooks/gcs.py | 9 +-------- .../google/cloud/secrets/secret_manager.py | 14 ++++++-------- .../microsoft/azure/secrets/key_vault.py | 14 ++++++-------- 5 files changed, 17 insertions(+), 44 deletions(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index ab4eda5d8ea78..bb85dedc1cdbd 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -34,12 +34,10 @@ from urllib.parse import urlparse import sqlparse -from packaging.version import Version from sqlalchemy import create_engine from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook -from airflow.version import version if TYPE_CHECKING: from pandas import DataFrame @@ -120,21 +118,7 @@ def connect(self, host: str, port: int, username: str, schema: str) -> Any: """ -# In case we are running it on Airflow 2.4+, we should use BaseHook, but on Airflow 2.3 and below -# We want the DbApiHook to derive from the original DbApiHook from airflow, because otherwise -# SqlSensor and BaseSqlOperator from "airflow.operators" and "airflow.sensors" will refuse to -# accept the new Hooks as not derived from the original DbApiHook -if Version(version) < Version("2.4"): - try: - from airflow.hooks.dbapi import DbApiHook as BaseForDbApiHook - except ImportError: - # just in case we have a problem with circular import - BaseForDbApiHook: type[BaseHook] = BaseHook # type: ignore[no-redef] -else: - BaseForDbApiHook: type[BaseHook] = BaseHook # type: ignore[no-redef] - - -class DbApiHook(BaseForDbApiHook): +class DbApiHook(BaseHook): """ Abstract base class for sql hooks. diff --git a/airflow/providers/common/sql/hooks/sql.pyi b/airflow/providers/common/sql/hooks/sql.pyi index dedac037dfcb2..084993a919f49 100644 --- a/airflow/providers/common/sql/hooks/sql.pyi +++ b/airflow/providers/common/sql/hooks/sql.pyi @@ -32,8 +32,8 @@ Definition of the public interface for airflow.providers.common.sql.hooks.sql isort:skip_file """ from _typeshed import Incomplete -from airflow.hooks.dbapi import DbApiHook as BaseForDbApiHook -from typing import Any, Callable, Iterable, Mapping, Sequence +from airflow.hooks.base import BaseHook +from typing import Any, Callable, Iterable, Mapping, Sequence, Union from typing_extensions import Protocol def return_single_query_results( @@ -45,7 +45,7 @@ def fetch_one_handler(cursor) -> Union[list[tuple], None]: ... class ConnectorProtocol(Protocol): def connect(self, host: str, port: int, username: str, schema: str) -> Any: ... -class DbApiHook(BaseForDbApiHook): +class DbApiHook(BaseHook): conn_name_attr: str default_conn_name: str supports_autocommit: bool diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py index d8bb36037fa10..1a74d09d55cbd 100644 --- a/airflow/providers/google/cloud/hooks/gcs.py +++ b/airflow/providers/google/cloud/hooks/gcs.py @@ -45,6 +45,7 @@ from airflow.providers.google.cloud.utils.helpers import normalize_directory_path from airflow.providers.google.common.consts import CLIENT_INFO from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook +from airflow.typing_compat import ParamSpec from airflow.utils import timezone from airflow.version import version @@ -54,14 +55,6 @@ from aiohttp import ClientSession from google.api_core.retry import Retry -try: - # Airflow 2.3 doesn't have this yet - from airflow.typing_compat import ParamSpec -except ImportError: - try: - from typing import ParamSpec # type: ignore[no-redef, attr-defined] - except ImportError: - from typing_extensions import ParamSpec RT = TypeVar("RT") T = TypeVar("T", bound=Callable) diff --git a/airflow/providers/google/cloud/secrets/secret_manager.py b/airflow/providers/google/cloud/secrets/secret_manager.py index fd8b8e33e28df..a40c6bfbe5fb4 100644 --- a/airflow/providers/google/cloud/secrets/secret_manager.py +++ b/airflow/providers/google/cloud/secrets/secret_manager.py @@ -28,7 +28,6 @@ from airflow.providers.google.cloud.utils.credentials_provider import get_credentials_and_project_id from airflow.secrets import BaseSecretsBackend from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.version import version as airflow_version log = logging.getLogger(__name__) @@ -154,13 +153,12 @@ def get_conn_uri(self, conn_id: str) -> str | None: :param conn_id: the connection id :return: deserialized Connection """ - if _parse_version(airflow_version) >= (2, 3): - warnings.warn( - f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed " - "in a future release. Please use method `get_conn_value` instead.", - AirflowProviderDeprecationWarning, - stacklevel=2, - ) + warnings.warn( + f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed " + "in a future release. Please use method `get_conn_value` instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) return self.get_conn_value(conn_id) def get_variable(self, key: str) -> str | None: diff --git a/airflow/providers/microsoft/azure/secrets/key_vault.py b/airflow/providers/microsoft/azure/secrets/key_vault.py index 794788206c137..bfa9117b111dd 100644 --- a/airflow/providers/microsoft/azure/secrets/key_vault.py +++ b/airflow/providers/microsoft/azure/secrets/key_vault.py @@ -38,7 +38,6 @@ from airflow.providers.microsoft.azure.utils import get_sync_default_azure_credential from airflow.secrets import BaseSecretsBackend from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.version import version as airflow_version def _parse_version(val): @@ -170,13 +169,12 @@ def get_conn_uri(self, conn_id: str) -> str | None: :param conn_id: the connection id :return: deserialized Connection """ - if _parse_version(airflow_version) >= (2, 3): - warnings.warn( - f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed " - "in a future release. Please use method `get_conn_value` instead.", - AirflowProviderDeprecationWarning, - stacklevel=2, - ) + warnings.warn( + f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed " + "in a future release. Please use method `get_conn_value` instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) return self.get_conn_value(conn_id) def get_variable(self, key: str) -> str | None: From 2f6598b1f6c05315444f11cda72e0007f96093f5 Mon Sep 17 00:00:00 2001 From: Andrey Anshin Date: Sat, 18 Nov 2023 22:21:49 +0400 Subject: [PATCH 2/2] Revert changes in sql.py --- airflow/providers/common/sql/hooks/sql.py | 18 +++++++++++++++++- airflow/providers/common/sql/hooks/sql.pyi | 6 +++--- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index bb85dedc1cdbd..ab4eda5d8ea78 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -34,10 +34,12 @@ from urllib.parse import urlparse import sqlparse +from packaging.version import Version from sqlalchemy import create_engine from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook +from airflow.version import version if TYPE_CHECKING: from pandas import DataFrame @@ -118,7 +120,21 @@ def connect(self, host: str, port: int, username: str, schema: str) -> Any: """ -class DbApiHook(BaseHook): +# In case we are running it on Airflow 2.4+, we should use BaseHook, but on Airflow 2.3 and below +# We want the DbApiHook to derive from the original DbApiHook from airflow, because otherwise +# SqlSensor and BaseSqlOperator from "airflow.operators" and "airflow.sensors" will refuse to +# accept the new Hooks as not derived from the original DbApiHook +if Version(version) < Version("2.4"): + try: + from airflow.hooks.dbapi import DbApiHook as BaseForDbApiHook + except ImportError: + # just in case we have a problem with circular import + BaseForDbApiHook: type[BaseHook] = BaseHook # type: ignore[no-redef] +else: + BaseForDbApiHook: type[BaseHook] = BaseHook # type: ignore[no-redef] + + +class DbApiHook(BaseForDbApiHook): """ Abstract base class for sql hooks. diff --git a/airflow/providers/common/sql/hooks/sql.pyi b/airflow/providers/common/sql/hooks/sql.pyi index 084993a919f49..dedac037dfcb2 100644 --- a/airflow/providers/common/sql/hooks/sql.pyi +++ b/airflow/providers/common/sql/hooks/sql.pyi @@ -32,8 +32,8 @@ Definition of the public interface for airflow.providers.common.sql.hooks.sql isort:skip_file """ from _typeshed import Incomplete -from airflow.hooks.base import BaseHook -from typing import Any, Callable, Iterable, Mapping, Sequence, Union +from airflow.hooks.dbapi import DbApiHook as BaseForDbApiHook +from typing import Any, Callable, Iterable, Mapping, Sequence from typing_extensions import Protocol def return_single_query_results( @@ -45,7 +45,7 @@ def fetch_one_handler(cursor) -> Union[list[tuple], None]: ... class ConnectorProtocol(Protocol): def connect(self, host: str, port: int, username: str, schema: str) -> Any: ... -class DbApiHook(BaseHook): +class DbApiHook(BaseForDbApiHook): conn_name_attr: str default_conn_name: str supports_autocommit: bool