Skip to content

Commit

Permalink
Helper for provide_session-decorated functions (#20104)
Browse files Browse the repository at this point in the history
* Helper for provide_session-decorated functions

* Apply NEW_SESSION trick on XCom

(cherry picked from commit a80ac1e)
  • Loading branch information
uranusjr authored and jedcunningham committed Feb 17, 2022
1 parent 016929f commit dda864d
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 18 deletions.
24 changes: 12 additions & 12 deletions airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from airflow.utils import timezone
from airflow.utils.helpers import is_container
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import provide_session
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -90,7 +90,7 @@ def set(
dag_id: str,
task_id: str,
run_id: str,
session: Optional[Session] = None,
session: Session = NEW_SESSION,
) -> None:
"""Store an XCom value.
Expand All @@ -116,7 +116,7 @@ def set(
task_id: str,
dag_id: str,
execution_date: datetime.datetime,
session: Optional[Session] = None,
session: Session = NEW_SESSION,
) -> None:
""":sphinx-autoapi-skip:"""

Expand All @@ -129,7 +129,7 @@ def set(
task_id: str,
dag_id: str,
execution_date: Optional[datetime.datetime] = None,
session: Session = None,
session: Session = NEW_SESSION,
*,
run_id: Optional[str] = None,
) -> None:
Expand Down Expand Up @@ -170,7 +170,7 @@ def get_one(
task_id: Optional[str] = None,
dag_id: Optional[str] = None,
include_prior_dates: bool = False,
session: Optional[Session] = None,
session: Session = NEW_SESSION,
) -> Optional[Any]:
"""Retrieve an XCom value, optionally meeting certain criteria.
Expand Down Expand Up @@ -207,7 +207,7 @@ def get_one(
task_id: Optional[str] = None,
dag_id: Optional[str] = None,
include_prior_dates: bool = False,
session: Optional[Session] = None,
session: Session = NEW_SESSION,
) -> Optional[Any]:
""":sphinx-autoapi-skip:"""

Expand All @@ -220,7 +220,7 @@ def get_one(
task_id: Optional[Union[str, Iterable[str]]] = None,
dag_id: Optional[Union[str, Iterable[str]]] = None,
include_prior_dates: bool = False,
session: Session = None,
session: Session = NEW_SESSION,
*,
run_id: Optional[str] = None,
) -> Optional[Any]:
Expand Down Expand Up @@ -265,7 +265,7 @@ def get_many(
dag_ids: Union[str, Iterable[str], None] = None,
include_prior_dates: bool = False,
limit: Optional[int] = None,
session: Optional[Session] = None,
session: Session = NEW_SESSION,
) -> Query:
"""Composes a query to get one or more XCom entries.
Expand Down Expand Up @@ -300,7 +300,7 @@ def get_many(
dag_ids: Union[str, Iterable[str], None] = None,
include_prior_dates: bool = False,
limit: Optional[int] = None,
session: Optional[Session] = None,
session: Session = NEW_SESSION,
) -> Query:
""":sphinx-autoapi-skip:"""

Expand All @@ -314,7 +314,7 @@ def get_many(
dag_ids: Optional[Union[str, Iterable[str]]] = None,
include_prior_dates: bool = False,
limit: Optional[int] = None,
session: Session = None,
session: Session = NEW_SESSION,
*,
run_id: Optional[str] = None,
) -> Query:
Expand Down Expand Up @@ -397,7 +397,7 @@ def clear(
execution_date: pendulum.DateTime,
dag_id: str,
task_id: str,
session: Optional[Session] = None,
session: Session = NEW_SESSION,
) -> None:
""":sphinx-autoapi-skip:"""

Expand All @@ -409,7 +409,7 @@ def clear(
dag_id: Optional[str] = None,
task_id: Optional[str] = None,
run_id: Optional[str] = None,
session: Session = None,
session: Session = NEW_SESSION,
) -> None:
""":sphinx-autoapi-skip:"""
# Given the historic order of this function (execution_date was first argument) to add a new optional
Expand Down
10 changes: 6 additions & 4 deletions airflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import os
import sys
import warnings
from typing import Optional
from typing import TYPE_CHECKING, Callable, List, Optional

import pendulum
import sqlalchemy
Expand All @@ -37,6 +37,9 @@
from airflow.logging_config import configure_logging
from airflow.utils.orm_event_handlers import setup_event_handlers

if TYPE_CHECKING:
from airflow.www.utils import UIAlert

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -77,7 +80,7 @@
DAGS_FOLDER: str = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))

engine: Optional[Engine] = None
Session: Optional[SASession] = None
Session: Callable[..., SASession]

# The JSON library to use for DAG Serialization and De-Serialization
json = json
Expand Down Expand Up @@ -563,8 +566,7 @@ def initialize():
# UIAlert('Visit <a href="http://airflow.apache.org">airflow.apache.org</a>', html=True),
# ]
#
# DASHBOARD_UIALERTS: List["UIAlert"]
DASHBOARD_UIALERTS = []
DASHBOARD_UIALERTS: List["UIAlert"] = []

# Prefix used to identify tables holding data moved during migration.
AIRFLOW_MOVED_TABLE_PREFIX = "_airflow_moved"
11 changes: 9 additions & 2 deletions airflow/utils/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
import contextlib
from functools import wraps
from inspect import signature
from typing import Callable, Iterator, TypeVar
from typing import Callable, Iterator, TypeVar, cast

from airflow import settings


@contextlib.contextmanager
def create_session() -> Iterator[settings.SASession]:
"""Contextmanager that will create and teardown a session."""
session: settings.SASession = settings.Session()
session = settings.Session()
try:
yield session
session.commit()
Expand Down Expand Up @@ -105,3 +105,10 @@ def create_global_lock(session=None, pg_lock_id=1, lock_name='init', mysql_lock_
if dialect.name == 'mssql':
# TODO: make locking works for MSSQL
pass


# A fake session to use in functions decorated by provide_session. This allows
# the 'session' argument to be of type Session instead of Optional[Session],
# making it easier to type hint the function body without dealing with the None
# case that can never happen at runtime.
NEW_SESSION: settings.SASession = cast(settings.SASession, None)

0 comments on commit dda864d

Please sign in to comment.