Skip to content

Commit

Permalink
Lazy Jinja2 context (#20217)
Browse files Browse the repository at this point in the history
Co-authored-by: Jed Cunningham <[email protected]>
  • Loading branch information
uranusjr and jedcunningham authored Dec 14, 2021
1 parent 5011cb3 commit 181d60c
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 46 deletions.
25 changes: 16 additions & 9 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
from airflow.triggers.base import BaseTrigger
from airflow.utils import timezone
from airflow.utils.edgemodifier import EdgeModifier
from airflow.utils.helpers import validate_key
from airflow.utils.helpers import render_template_as_native, render_template_to_string, validate_key
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.operator_resources import Resources
from airflow.utils.session import NEW_SESSION, provide_session
Expand Down Expand Up @@ -1042,7 +1042,11 @@ def __setstate__(self, state):
self.__dict__ = state
self._log = logging.getLogger("airflow.task.operators")

def render_template_fields(self, context: Dict, jinja_env: Optional[jinja2.Environment] = None) -> None:
def render_template_fields(
self,
context: Context,
jinja_env: Optional[jinja2.Environment] = None,
) -> None:
"""
Template all attributes listed in template_fields. Note this operation is irreversible.
Expand All @@ -1060,7 +1064,7 @@ def _do_render_template_fields(
self,
parent: Any,
template_fields: Iterable[str],
context: Dict,
context: Context,
jinja_env: jinja2.Environment,
seen_oids: Set,
) -> None:
Expand All @@ -1073,7 +1077,7 @@ def _do_render_template_fields(
def render_template(
self,
content: Any,
context: Dict,
context: Context,
jinja_env: Optional[jinja2.Environment] = None,
seen_oids: Optional[Set] = None,
) -> Any:
Expand All @@ -1100,11 +1104,14 @@ def render_template(
from airflow.models.xcom_arg import XComArg

if isinstance(content, str):
if any(content.endswith(ext) for ext in self.template_ext):
# Content contains a filepath
return jinja_env.get_template(content).render(**context)
if any(content.endswith(ext) for ext in self.template_ext): # Content contains a filepath.
template = jinja_env.get_template(content)
else:
return jinja_env.from_string(content).render(**context)
template = jinja_env.from_string(content)
if self.has_dag() and self.dag.render_template_as_native_obj:
return render_template_as_native(template, context)
return render_template_to_string(template, context)

elif isinstance(content, (XComArg, DagParam)):
return content.resolve(context)

Expand Down Expand Up @@ -1133,7 +1140,7 @@ def render_template(
return content

def _render_nested_template_fields(
self, content: Any, context: Dict, jinja_env: jinja2.Environment, seen_oids: Set
self, content: Any, context: Context, jinja_env: jinja2.Environment, seen_oids: Set
) -> None:
if id(content) not in seen_oids:
seen_oids.add(id(content))
Expand Down
3 changes: 2 additions & 1 deletion airflow/models/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from jsonschema.exceptions import ValidationError

from airflow.exceptions import AirflowException
from airflow.utils.context import Context
from airflow.utils.types import NOTSET, ArgNotSet


Expand Down Expand Up @@ -234,7 +235,7 @@ def __init__(self, current_dag, name: str, default: Optional[Any] = None):
self._name = name
self._default = default

def resolve(self, context: Dict) -> Any:
def resolve(self, context: Context) -> Any:
"""Pull DagParam value from DagRun context. This method is run during ``op.execute()``."""
default = self._default
if not self._default:
Expand Down
5 changes: 3 additions & 2 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, List, Optional, Sequence, Union

from airflow.exceptions import AirflowException
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskmixin import TaskMixin
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.utils.context import Context
from airflow.utils.edgemodifier import EdgeModifier


Expand Down Expand Up @@ -128,7 +129,7 @@ def set_downstream(
"""Proxy to underlying operator set_downstream method. Required by TaskMixin."""
self.operator.set_downstream(task_or_task_list, edge_modifier)

def resolve(self, context: Dict) -> Any:
def resolve(self, context: Context) -> Any:
"""
Pull XCom value for the existing arg. This method is run during ``op.execute()``
in respectable context.
Expand Down
15 changes: 13 additions & 2 deletions airflow/ti_deps/deps/trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

from collections import Counter

from sqlalchemy.orm import Session

from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.session import provide_session
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule as TR

Expand Down Expand Up @@ -82,7 +84,16 @@ def _get_dep_statuses(self, ti, session, dep_context):

@provide_session
def _evaluate_trigger_rule(
self, ti, successes, skipped, failed, upstream_failed, done, flag_upstream_failed, session
self,
ti,
successes,
skipped,
failed,
upstream_failed,
done,
flag_upstream_failed,
*,
session: Session = NEW_SESSION,
):
"""
Yields a dependency status that indicate whether the given task instance's trigger
Expand Down
57 changes: 41 additions & 16 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,22 @@
"""Jinja2 template rendering context helper."""

import contextlib
import copy
import warnings
from typing import Any, Container, Dict, Iterable, Iterator, List, MutableMapping, Tuple

_NOT_SET: Any = object()
from typing import (
AbstractSet,
Any,
Container,
Dict,
Iterator,
List,
MutableMapping,
Optional,
Tuple,
ValuesView,
)

from airflow.utils.types import NOTSET


class VariableAccessor:
Expand All @@ -41,10 +53,10 @@ def __getattr__(self, key: str) -> Any:
def __repr__(self) -> str:
return str(self.var)

def get(self, key, default: Any = _NOT_SET) -> Any:
def get(self, key, default: Any = NOTSET) -> Any:
from airflow.models.variable import Variable

if default is _NOT_SET:
if default is NOTSET:
return Variable.get(key, deserialize_json=self._deserialize_json)
return Variable.get(key, default, deserialize_json=self._deserialize_json)

Expand Down Expand Up @@ -74,16 +86,20 @@ def get(self, key: str, default_conn: Any = None) -> Any:
return default_conn


class AirflowContextDeprecationWarning(DeprecationWarning):
"""Warn for usage of deprecated context variables in a task."""


def _create_deprecation_warning(key: str, replacements: List[str]) -> DeprecationWarning:
message = f"Accessing {key!r} from the template is deprecated and will be removed in a future version."
if not replacements:
return DeprecationWarning(message)
return AirflowContextDeprecationWarning(message)
display_except_last = ", ".join(repr(r) for r in replacements[:-1])
if display_except_last:
message += f" Please use {display_except_last} or {replacements[-1]!r} instead."
else:
message += f" Please use {replacements[-1]!r} instead."
return DeprecationWarning(message)
return AirflowContextDeprecationWarning(message)


class Context(MutableMapping[str, Any]):
Expand All @@ -108,8 +124,10 @@ class Context(MutableMapping[str, Any]):
"yesterday_ds_nodash": [],
}

def __init__(self, context: MutableMapping[str, Any]) -> None:
self._context = context
def __init__(self, context: Optional[MutableMapping[str, Any]] = None, **kwargs: Any) -> None:
self._context = context or {}
if kwargs:
self._context.update(kwargs)
self._deprecation_replacements = self._DEPRECATION_REPLACEMENTS.copy()

def __repr__(self) -> str:
Expand All @@ -124,9 +142,14 @@ def __reduce_ex__(self, protocol: int) -> Tuple[Any, ...]:
items = [(key, self[key]) for key in self._context]
return dict, (items,)

def __copy__(self) -> "Context":
new = type(self)(copy.copy(self._context))
new._deprecation_replacements = self._deprecation_replacements.copy()
return new

def __getitem__(self, key: str) -> Any:
with contextlib.suppress(KeyError):
warnings.warn(_create_deprecation_warning(key, self._deprecation_replacements[key]), stacklevel=2)
warnings.warn(_create_deprecation_warning(key, self._deprecation_replacements[key]))
with contextlib.suppress(KeyError):
return self._context[key]
raise KeyError(key)
Expand All @@ -139,7 +162,7 @@ def __delitem__(self, key: str) -> None:
self._deprecation_replacements.pop(key, None)
del self._context[key]

def __contains__(self, key: str) -> bool:
def __contains__(self, key: object) -> bool:
return key in self._context

def __iter__(self) -> Iterator[str]:
Expand All @@ -158,14 +181,16 @@ def __ne__(self, other: Any) -> bool:
return NotImplemented
return self._context != other._context

def keys(self) -> Iterable[str]:
def keys(self) -> AbstractSet[str]:
return self._context.keys()

def items(self) -> Iterable[Tuple[str, Any]]:
def items(self) -> AbstractSet[Tuple[str, Any]]:
return self._context.items()

def values(self) -> Iterable[Any]:
def values(self) -> ValuesView[Any]:
return self._context.values()

def copy_only(self, keys: Container[str]) -> "Context[str, Any]":
return type(self)({k: v for k, v in self._context.items() if k in keys})
def copy_only(self, keys: Container[str]) -> "Context":
new = type(self)({k: v for k, v in self._context.items() if k in keys})
new._deprecation_replacements = self._deprecation_replacements.copy()
return new
50 changes: 46 additions & 4 deletions airflow/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import copy
import re
import signal
import warnings
Expand All @@ -24,11 +25,13 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, TypeVar
from urllib import parse

from flask import url_for
from jinja2 import Template
import flask
import jinja2
import jinja2.nativetypes

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.utils.context import Context
from airflow.utils.module_loading import import_string

if TYPE_CHECKING:
Expand Down Expand Up @@ -160,7 +163,7 @@ def as_flattened_list(iterable: Iterable[Iterable[T]]) -> List[T]:
def parse_template_string(template_string):
"""Parses Jinja template string."""
if "{{" in template_string: # jinja mode
return None, Template(template_string)
return None, jinja2.Template(template_string)
else:
return template_string, None

Expand Down Expand Up @@ -242,5 +245,44 @@ def build_airflow_url_with_query(query: Dict[str, Any]) -> str:
'http://0.0.0.0:8000/base/graph?dag_id=my-task&root=&execution_date=2020-10-27T10%3A59%3A25.615587
"""
view = conf.get('webserver', 'dag_default_view').lower()
url = url_for(f"Airflow.{view}")
url = flask.url_for(f"Airflow.{view}")
return f"{url}?{parse.urlencode(query)}"


# The 'template' argument is typed as Any because the jinja2.Template is too
# dynamic to be effectively type-checked.
def render_template(template: Any, context: Context, *, native: bool) -> Any:
"""Render a Jinja2 template with given Airflow context.
The default implementation of ``jinja2.Template.render()`` converts the
input context into dict eagerly many times, which triggers deprecation
messages in our custom context class. This takes the implementation apart
and retain the context mapping without resolving instead.
:param template: A Jinja2 template to render.
:param context: The Airflow task context to render the template with.
:param native: If set to *True*, render the template into a native type. A
DAG can enable this with ``render_template_as_native_obj=True``.
:returns: The render result.
"""
context = copy.copy(context)
env = template.environment
if template.globals:
context.update((k, v) for k, v in template.globals.items() if k not in context)
try:
nodes = template.root_render_func(env.context_class(env, context, template.name, template.blocks))
except Exception:
env.handle_exception() # Rewrite traceback to point to the template.
if native:
return jinja2.nativetypes.native_concat(nodes)
return "".join(nodes)


def render_template_to_string(template: jinja2.Template, context: Context) -> str:
"""Shorthand to ``render_template(native=False)`` with better typing support."""
return render_template(template, context, native=False)


def render_template_as_native(template: jinja2.Template, context: Context) -> Any:
"""Shorthand to ``render_template(native=True)`` with better typing support."""
return render_template(template, context, native=True)
21 changes: 9 additions & 12 deletions airflow/utils/log/file_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from itsdangerous import TimedJSONWebSignatureSerializer

from airflow.configuration import AirflowConfigException, conf
from airflow.utils.helpers import parse_template_string
from airflow.utils.context import Context
from airflow.utils.helpers import parse_template_string, render_template_to_string
from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler

if TYPE_CHECKING:
Expand Down Expand Up @@ -73,23 +74,19 @@ def close(self):
if self.handler:
self.handler.close()

def _render_filename(self, ti, try_number):
def _render_filename(self, ti: "TaskInstance", try_number: int) -> str:
if self.filename_jinja_template:
if hasattr(ti, 'task'):
jinja_context = ti.get_template_context()
jinja_context['try_number'] = try_number
if hasattr(ti, "task"):
context = ti.get_template_context()
else:
jinja_context = {
'ti': ti,
'ts': ti.execution_date.isoformat(),
'try_number': try_number,
}
return self.filename_jinja_template.render(**jinja_context)
context = Context(ti=ti, ts=ti.get_dagrun().logical_date.isoformat())
context["try_number"] = try_number
return render_template_to_string(self.filename_jinja_template, context)

return self.filename_template.format(
dag_id=ti.dag_id,
task_id=ti.task_id,
execution_date=ti.execution_date.isoformat(),
execution_date=ti.get_dagrun().logical_date.isoformat(),
try_number=try_number,
)

Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def breeze_test_helper(request):


def pytest_configure(config):
config.addinivalue_line("filterwarnings", "error::airflow.utils.context.AirflowContextDeprecationWarning")
config.addinivalue_line("markers", "integration(name): mark test to run with named integration")
config.addinivalue_line("markers", "backend(name): mark test to run with named backend")
config.addinivalue_line("markers", "system(name): mark test to run with named system")
Expand Down
Loading

0 comments on commit 181d60c

Please sign in to comment.