From 70b2a533397aadd646b438daf9250c7b1152fe44 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Tue, 7 Dec 2021 19:45:36 +0800 Subject: [PATCH] Helper for provide_session-decorated functions --- airflow/settings.py | 10 ++++++---- airflow/utils/session.py | 11 +++++++++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/airflow/settings.py b/airflow/settings.py index d7499ad0998bf..9cfed37bf3638 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -22,7 +22,7 @@ import os import sys import warnings -from typing import Optional +from typing import TYPE_CHECKING, Callable, List, Optional import pendulum import sqlalchemy @@ -37,6 +37,9 @@ from airflow.logging_config import configure_logging from airflow.utils.orm_event_handlers import setup_event_handlers +if TYPE_CHECKING: + from airflow.www.utils import UIAlert + log = logging.getLogger(__name__) @@ -77,7 +80,7 @@ DAGS_FOLDER: str = os.path.expanduser(conf.get('core', 'DAGS_FOLDER')) engine: Optional[Engine] = None -Session: Optional[SASession] = None +Session: Callable[..., SASession] # The JSON library to use for DAG Serialization and De-Serialization json = json @@ -563,8 +566,7 @@ def initialize(): # UIAlert('Visit airflow.apache.org', html=True), # ] # -# DASHBOARD_UIALERTS: List["UIAlert"] -DASHBOARD_UIALERTS = [] +DASHBOARD_UIALERTS: List["UIAlert"] = [] # Prefix used to identify tables holding data moved during migration. AIRFLOW_MOVED_TABLE_PREFIX = "_airflow_moved" diff --git a/airflow/utils/session.py b/airflow/utils/session.py index 3737635569e5e..e9bf9b461a375 100644 --- a/airflow/utils/session.py +++ b/airflow/utils/session.py @@ -17,7 +17,7 @@ import contextlib from functools import wraps from inspect import signature -from typing import Callable, Iterator, TypeVar +from typing import Callable, Iterator, TypeVar, cast from airflow import settings @@ -25,7 +25,7 @@ @contextlib.contextmanager def create_session() -> Iterator[settings.SASession]: """Contextmanager that will create and teardown a session.""" - session: settings.SASession = settings.Session() + session = settings.Session() try: yield session session.commit() @@ -69,3 +69,10 @@ def wrapper(*args, **kwargs) -> RT: return func(*args, session=session, **kwargs) return wrapper + + +# A fake session to use in functions decorated by provide_session. This allows +# the 'session' argument to be of type Session instead of Optional[Session], +# making it easier to type hint the function body without dealing with the None +# case that can never happen at runtime. +NEW_SESSION: settings.SASession = cast(settings.SASession, None)