Skip to content

Commit

Permalink
Helper for provide_session-decorated functions
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr committed Dec 7, 2021
1 parent 6dd0a0d commit 70b2a53
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
10 changes: 6 additions & 4 deletions airflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import os
import sys
import warnings
from typing import Optional
from typing import TYPE_CHECKING, Callable, List, Optional

import pendulum
import sqlalchemy
Expand All @@ -37,6 +37,9 @@
from airflow.logging_config import configure_logging
from airflow.utils.orm_event_handlers import setup_event_handlers

if TYPE_CHECKING:
from airflow.www.utils import UIAlert

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -77,7 +80,7 @@
DAGS_FOLDER: str = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))

engine: Optional[Engine] = None
Session: Optional[SASession] = None
Session: Callable[..., SASession]

# The JSON library to use for DAG Serialization and De-Serialization
json = json
Expand Down Expand Up @@ -563,8 +566,7 @@ def initialize():
# UIAlert('Visit <a href="http://airflow.apache.org">airflow.apache.org</a>', html=True),
# ]
#
# DASHBOARD_UIALERTS: List["UIAlert"]
DASHBOARD_UIALERTS = []
DASHBOARD_UIALERTS: List["UIAlert"] = []

# Prefix used to identify tables holding data moved during migration.
AIRFLOW_MOVED_TABLE_PREFIX = "_airflow_moved"
11 changes: 9 additions & 2 deletions airflow/utils/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
import contextlib
from functools import wraps
from inspect import signature
from typing import Callable, Iterator, TypeVar
from typing import Callable, Iterator, TypeVar, cast

from airflow import settings


@contextlib.contextmanager
def create_session() -> Iterator[settings.SASession]:
"""Contextmanager that will create and teardown a session."""
session: settings.SASession = settings.Session()
session = settings.Session()
try:
yield session
session.commit()
Expand Down Expand Up @@ -69,3 +69,10 @@ def wrapper(*args, **kwargs) -> RT:
return func(*args, session=session, **kwargs)

return wrapper


# A fake session to use in functions decorated by provide_session. This allows
# the 'session' argument to be of type Session instead of Optional[Session],
# making it easier to type hint the function body without dealing with the None
# case that can never happen at runtime.
NEW_SESSION: settings.SASession = cast(settings.SASession, None)

0 comments on commit 70b2a53

Please sign in to comment.