From d5721619bc970e42abdc474362fcf9c151bc33a8 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Tue, 7 Dec 2021 21:50:34 +0800 Subject: [PATCH] Helper for provide_session-decorated functions (#20104) * Helper for provide_session-decorated functions * Apply NEW_SESSION trick on XCom (cherry picked from commit a80ac1eecc0ea187de7984510b4ef6f981b97196) --- airflow/models/xcom.py | 24 ++++++++++++------------ airflow/settings.py | 10 ++++++---- airflow/utils/session.py | 11 +++++++++-- 3 files changed, 27 insertions(+), 18 deletions(-) diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index 4bb9689e7dda6..5efaa0ac54f61 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -32,7 +32,7 @@ from airflow.utils import timezone from airflow.utils.helpers import is_container from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import UtcDateTime log = logging.getLogger(__name__) @@ -90,7 +90,7 @@ def set( dag_id: str, task_id: str, run_id: str, - session: Optional[Session] = None, + session: Session = NEW_SESSION, ) -> None: """Store an XCom value. @@ -116,7 +116,7 @@ def set( task_id: str, dag_id: str, execution_date: datetime.datetime, - session: Optional[Session] = None, + session: Session = NEW_SESSION, ) -> None: """:sphinx-autoapi-skip:""" @@ -129,7 +129,7 @@ def set( task_id: str, dag_id: str, execution_date: Optional[datetime.datetime] = None, - session: Session = None, + session: Session = NEW_SESSION, *, run_id: Optional[str] = None, ) -> None: @@ -170,7 +170,7 @@ def get_one( task_id: Optional[str] = None, dag_id: Optional[str] = None, include_prior_dates: bool = False, - session: Optional[Session] = None, + session: Session = NEW_SESSION, ) -> Optional[Any]: """Retrieve an XCom value, optionally meeting certain criteria. @@ -207,7 +207,7 @@ def get_one( task_id: Optional[str] = None, dag_id: Optional[str] = None, include_prior_dates: bool = False, - session: Optional[Session] = None, + session: Session = NEW_SESSION, ) -> Optional[Any]: """:sphinx-autoapi-skip:""" @@ -220,7 +220,7 @@ def get_one( task_id: Optional[Union[str, Iterable[str]]] = None, dag_id: Optional[Union[str, Iterable[str]]] = None, include_prior_dates: bool = False, - session: Session = None, + session: Session = NEW_SESSION, *, run_id: Optional[str] = None, ) -> Optional[Any]: @@ -265,7 +265,7 @@ def get_many( dag_ids: Union[str, Iterable[str], None] = None, include_prior_dates: bool = False, limit: Optional[int] = None, - session: Optional[Session] = None, + session: Session = NEW_SESSION, ) -> Query: """Composes a query to get one or more XCom entries. @@ -300,7 +300,7 @@ def get_many( dag_ids: Union[str, Iterable[str], None] = None, include_prior_dates: bool = False, limit: Optional[int] = None, - session: Optional[Session] = None, + session: Session = NEW_SESSION, ) -> Query: """:sphinx-autoapi-skip:""" @@ -314,7 +314,7 @@ def get_many( dag_ids: Optional[Union[str, Iterable[str]]] = None, include_prior_dates: bool = False, limit: Optional[int] = None, - session: Session = None, + session: Session = NEW_SESSION, *, run_id: Optional[str] = None, ) -> Query: @@ -397,7 +397,7 @@ def clear( execution_date: pendulum.DateTime, dag_id: str, task_id: str, - session: Optional[Session] = None, + session: Session = NEW_SESSION, ) -> None: """:sphinx-autoapi-skip:""" @@ -409,7 +409,7 @@ def clear( dag_id: Optional[str] = None, task_id: Optional[str] = None, run_id: Optional[str] = None, - session: Session = None, + session: Session = NEW_SESSION, ) -> None: """:sphinx-autoapi-skip:""" # Given the historic order of this function (execution_date was first argument) to add a new optional diff --git a/airflow/settings.py b/airflow/settings.py index f9b97a25c4c2a..139d6a40b18d6 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -22,7 +22,7 @@ import os import sys import warnings -from typing import Optional +from typing import TYPE_CHECKING, Callable, List, Optional import pendulum import sqlalchemy @@ -37,6 +37,9 @@ from airflow.logging_config import configure_logging from airflow.utils.orm_event_handlers import setup_event_handlers +if TYPE_CHECKING: + from airflow.www.utils import UIAlert + log = logging.getLogger(__name__) @@ -77,7 +80,7 @@ DAGS_FOLDER: str = os.path.expanduser(conf.get('core', 'DAGS_FOLDER')) engine: Optional[Engine] = None -Session: Optional[SASession] = None +Session: Callable[..., SASession] # The JSON library to use for DAG Serialization and De-Serialization json = json @@ -563,8 +566,7 @@ def initialize(): # UIAlert('Visit airflow.apache.org', html=True), # ] # -# DASHBOARD_UIALERTS: List["UIAlert"] -DASHBOARD_UIALERTS = [] +DASHBOARD_UIALERTS: List["UIAlert"] = [] # Prefix used to identify tables holding data moved during migration. AIRFLOW_MOVED_TABLE_PREFIX = "_airflow_moved" diff --git a/airflow/utils/session.py b/airflow/utils/session.py index 9636fc401e6cc..f0c31687ff1ab 100644 --- a/airflow/utils/session.py +++ b/airflow/utils/session.py @@ -18,7 +18,7 @@ import contextlib from functools import wraps from inspect import signature -from typing import Callable, Iterator, TypeVar +from typing import Callable, Iterator, TypeVar, cast from airflow import settings @@ -26,7 +26,7 @@ @contextlib.contextmanager def create_session() -> Iterator[settings.SASession]: """Contextmanager that will create and teardown a session.""" - session: settings.SASession = settings.Session() + session = settings.Session() try: yield session session.commit() @@ -105,3 +105,10 @@ def create_global_lock(session=None, pg_lock_id=1, lock_name='init', mysql_lock_ if dialect.name == 'mssql': # TODO: make locking works for MSSQL pass + + +# A fake session to use in functions decorated by provide_session. This allows +# the 'session' argument to be of type Session instead of Optional[Session], +# making it easier to type hint the function body without dealing with the None +# case that can never happen at runtime. +NEW_SESSION: settings.SASession = cast(settings.SASession, None)