diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 81ef98d32b026..4458f2406431a 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -71,7 +71,7 @@ from airflow.triggers.base import BaseTrigger from airflow.utils import timezone from airflow.utils.edgemodifier import EdgeModifier -from airflow.utils.helpers import validate_key +from airflow.utils.helpers import render_template_as_native, render_template_to_string, validate_key from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.operator_resources import Resources from airflow.utils.session import NEW_SESSION, provide_session @@ -1042,7 +1042,11 @@ def __setstate__(self, state): self.__dict__ = state self._log = logging.getLogger("airflow.task.operators") - def render_template_fields(self, context: Dict, jinja_env: Optional[jinja2.Environment] = None) -> None: + def render_template_fields( + self, + context: Context, + jinja_env: Optional[jinja2.Environment] = None, + ) -> None: """ Template all attributes listed in template_fields. Note this operation is irreversible. @@ -1060,7 +1064,7 @@ def _do_render_template_fields( self, parent: Any, template_fields: Iterable[str], - context: Dict, + context: Context, jinja_env: jinja2.Environment, seen_oids: Set, ) -> None: @@ -1073,7 +1077,7 @@ def _do_render_template_fields( def render_template( self, content: Any, - context: Dict, + context: Context, jinja_env: Optional[jinja2.Environment] = None, seen_oids: Optional[Set] = None, ) -> Any: @@ -1100,11 +1104,14 @@ def render_template( from airflow.models.xcom_arg import XComArg if isinstance(content, str): - if any(content.endswith(ext) for ext in self.template_ext): - # Content contains a filepath - return jinja_env.get_template(content).render(**context) + if any(content.endswith(ext) for ext in self.template_ext): # Content contains a filepath. + template = jinja_env.get_template(content) else: - return jinja_env.from_string(content).render(**context) + template = jinja_env.from_string(content) + if self.has_dag() and self.dag.render_template_as_native_obj: + return render_template_as_native(template, context) + return render_template_to_string(template, context) + elif isinstance(content, (XComArg, DagParam)): return content.resolve(context) @@ -1133,7 +1140,7 @@ def render_template( return content def _render_nested_template_fields( - self, content: Any, context: Dict, jinja_env: jinja2.Environment, seen_oids: Set + self, content: Any, context: Context, jinja_env: jinja2.Environment, seen_oids: Set ) -> None: if id(content) not in seen_oids: seen_oids.add(id(content)) diff --git a/airflow/models/param.py b/airflow/models/param.py index 90cae38c443a7..0f6233d4088b6 100644 --- a/airflow/models/param.py +++ b/airflow/models/param.py @@ -22,6 +22,7 @@ from jsonschema.exceptions import ValidationError from airflow.exceptions import AirflowException +from airflow.utils.context import Context from airflow.utils.types import NOTSET, ArgNotSet @@ -234,7 +235,7 @@ def __init__(self, current_dag, name: str, default: Optional[Any] = None): self._name = name self._default = default - def resolve(self, context: Dict) -> Any: + def resolve(self, context: Context) -> Any: """Pull DagParam value from DagRun context. This method is run during ``op.execute()``.""" default = self._default if not self._default: diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index dd08ab34fa6ad..6503106e5aa11 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, List, Optional, Sequence, Union from airflow.exceptions import AirflowException from airflow.models.baseoperator import BaseOperator from airflow.models.taskmixin import TaskMixin from airflow.models.xcom import XCOM_RETURN_KEY +from airflow.utils.context import Context from airflow.utils.edgemodifier import EdgeModifier @@ -128,7 +129,7 @@ def set_downstream( """Proxy to underlying operator set_downstream method. Required by TaskMixin.""" self.operator.set_downstream(task_or_task_list, edge_modifier) - def resolve(self, context: Dict) -> Any: + def resolve(self, context: Context) -> Any: """ Pull XCom value for the existing arg. This method is run during ``op.execute()`` in respectable context. diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py index e54fc4c901917..94c19e808afca 100644 --- a/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow/ti_deps/deps/trigger_rule_dep.py @@ -18,8 +18,10 @@ from collections import Counter +from sqlalchemy.orm import Session + from airflow.ti_deps.deps.base_ti_dep import BaseTIDep -from airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import State from airflow.utils.trigger_rule import TriggerRule as TR @@ -82,7 +84,16 @@ def _get_dep_statuses(self, ti, session, dep_context): @provide_session def _evaluate_trigger_rule( - self, ti, successes, skipped, failed, upstream_failed, done, flag_upstream_failed, session + self, + ti, + successes, + skipped, + failed, + upstream_failed, + done, + flag_upstream_failed, + *, + session: Session = NEW_SESSION, ): """ Yields a dependency status that indicate whether the given task instance's trigger diff --git a/airflow/utils/context.py b/airflow/utils/context.py index fca55c1cb6043..d1c75bc409008 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -19,10 +19,22 @@ """Jinja2 template rendering context helper.""" import contextlib +import copy import warnings -from typing import Any, Container, Dict, Iterable, Iterator, List, MutableMapping, Tuple - -_NOT_SET: Any = object() +from typing import ( + AbstractSet, + Any, + Container, + Dict, + Iterator, + List, + MutableMapping, + Optional, + Tuple, + ValuesView, +) + +from airflow.utils.types import NOTSET class VariableAccessor: @@ -41,10 +53,10 @@ def __getattr__(self, key: str) -> Any: def __repr__(self) -> str: return str(self.var) - def get(self, key, default: Any = _NOT_SET) -> Any: + def get(self, key, default: Any = NOTSET) -> Any: from airflow.models.variable import Variable - if default is _NOT_SET: + if default is NOTSET: return Variable.get(key, deserialize_json=self._deserialize_json) return Variable.get(key, default, deserialize_json=self._deserialize_json) @@ -74,16 +86,20 @@ def get(self, key: str, default_conn: Any = None) -> Any: return default_conn +class AirflowContextDeprecationWarning(DeprecationWarning): + """Warn for usage of deprecated context variables in a task.""" + + def _create_deprecation_warning(key: str, replacements: List[str]) -> DeprecationWarning: message = f"Accessing {key!r} from the template is deprecated and will be removed in a future version." if not replacements: - return DeprecationWarning(message) + return AirflowContextDeprecationWarning(message) display_except_last = ", ".join(repr(r) for r in replacements[:-1]) if display_except_last: message += f" Please use {display_except_last} or {replacements[-1]!r} instead." else: message += f" Please use {replacements[-1]!r} instead." - return DeprecationWarning(message) + return AirflowContextDeprecationWarning(message) class Context(MutableMapping[str, Any]): @@ -108,8 +124,10 @@ class Context(MutableMapping[str, Any]): "yesterday_ds_nodash": [], } - def __init__(self, context: MutableMapping[str, Any]) -> None: - self._context = context + def __init__(self, context: Optional[MutableMapping[str, Any]] = None, **kwargs: Any) -> None: + self._context = context or {} + if kwargs: + self._context.update(kwargs) self._deprecation_replacements = self._DEPRECATION_REPLACEMENTS.copy() def __repr__(self) -> str: @@ -124,9 +142,14 @@ def __reduce_ex__(self, protocol: int) -> Tuple[Any, ...]: items = [(key, self[key]) for key in self._context] return dict, (items,) + def __copy__(self) -> "Context": + new = type(self)(copy.copy(self._context)) + new._deprecation_replacements = self._deprecation_replacements.copy() + return new + def __getitem__(self, key: str) -> Any: with contextlib.suppress(KeyError): - warnings.warn(_create_deprecation_warning(key, self._deprecation_replacements[key]), stacklevel=2) + warnings.warn(_create_deprecation_warning(key, self._deprecation_replacements[key])) with contextlib.suppress(KeyError): return self._context[key] raise KeyError(key) @@ -139,7 +162,7 @@ def __delitem__(self, key: str) -> None: self._deprecation_replacements.pop(key, None) del self._context[key] - def __contains__(self, key: str) -> bool: + def __contains__(self, key: object) -> bool: return key in self._context def __iter__(self) -> Iterator[str]: @@ -158,14 +181,16 @@ def __ne__(self, other: Any) -> bool: return NotImplemented return self._context != other._context - def keys(self) -> Iterable[str]: + def keys(self) -> AbstractSet[str]: return self._context.keys() - def items(self) -> Iterable[Tuple[str, Any]]: + def items(self) -> AbstractSet[Tuple[str, Any]]: return self._context.items() - def values(self) -> Iterable[Any]: + def values(self) -> ValuesView[Any]: return self._context.values() - def copy_only(self, keys: Container[str]) -> "Context[str, Any]": - return type(self)({k: v for k, v in self._context.items() if k in keys}) + def copy_only(self, keys: Container[str]) -> "Context": + new = type(self)({k: v for k, v in self._context.items() if k in keys}) + new._deprecation_replacements = self._deprecation_replacements.copy() + return new diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py index c1041d843a79e..42a0dfa736f8e 100644 --- a/airflow/utils/helpers.py +++ b/airflow/utils/helpers.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import copy import re import signal import warnings @@ -24,11 +25,13 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, TypeVar from urllib import parse -from flask import url_for -from jinja2 import Template +import flask +import jinja2 +import jinja2.nativetypes from airflow.configuration import conf from airflow.exceptions import AirflowException +from airflow.utils.context import Context from airflow.utils.module_loading import import_string if TYPE_CHECKING: @@ -160,7 +163,7 @@ def as_flattened_list(iterable: Iterable[Iterable[T]]) -> List[T]: def parse_template_string(template_string): """Parses Jinja template string.""" if "{{" in template_string: # jinja mode - return None, Template(template_string) + return None, jinja2.Template(template_string) else: return template_string, None @@ -242,5 +245,44 @@ def build_airflow_url_with_query(query: Dict[str, Any]) -> str: 'http://0.0.0.0:8000/base/graph?dag_id=my-task&root=&execution_date=2020-10-27T10%3A59%3A25.615587 """ view = conf.get('webserver', 'dag_default_view').lower() - url = url_for(f"Airflow.{view}") + url = flask.url_for(f"Airflow.{view}") return f"{url}?{parse.urlencode(query)}" + + +# The 'template' argument is typed as Any because the jinja2.Template is too +# dynamic to be effectively type-checked. +def render_template(template: Any, context: Context, *, native: bool) -> Any: + """Render a Jinja2 template with given Airflow context. + + The default implementation of ``jinja2.Template.render()`` converts the + input context into dict eagerly many times, which triggers deprecation + messages in our custom context class. This takes the implementation apart + and retain the context mapping without resolving instead. + + :param template: A Jinja2 template to render. + :param context: The Airflow task context to render the template with. + :param native: If set to *True*, render the template into a native type. A + DAG can enable this with ``render_template_as_native_obj=True``. + :returns: The render result. + """ + context = copy.copy(context) + env = template.environment + if template.globals: + context.update((k, v) for k, v in template.globals.items() if k not in context) + try: + nodes = template.root_render_func(env.context_class(env, context, template.name, template.blocks)) + except Exception: + env.handle_exception() # Rewrite traceback to point to the template. + if native: + return jinja2.nativetypes.native_concat(nodes) + return "".join(nodes) + + +def render_template_to_string(template: jinja2.Template, context: Context) -> str: + """Shorthand to ``render_template(native=False)`` with better typing support.""" + return render_template(template, context, native=False) + + +def render_template_as_native(template: jinja2.Template, context: Context) -> Any: + """Shorthand to ``render_template(native=True)`` with better typing support.""" + return render_template(template, context, native=True) diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index 8e5dbca04d231..cec808a417872 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -25,7 +25,8 @@ from itsdangerous import TimedJSONWebSignatureSerializer from airflow.configuration import AirflowConfigException, conf -from airflow.utils.helpers import parse_template_string +from airflow.utils.context import Context +from airflow.utils.helpers import parse_template_string, render_template_to_string from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler if TYPE_CHECKING: @@ -73,23 +74,19 @@ def close(self): if self.handler: self.handler.close() - def _render_filename(self, ti, try_number): + def _render_filename(self, ti: "TaskInstance", try_number: int) -> str: if self.filename_jinja_template: - if hasattr(ti, 'task'): - jinja_context = ti.get_template_context() - jinja_context['try_number'] = try_number + if hasattr(ti, "task"): + context = ti.get_template_context() else: - jinja_context = { - 'ti': ti, - 'ts': ti.execution_date.isoformat(), - 'try_number': try_number, - } - return self.filename_jinja_template.render(**jinja_context) + context = Context(ti=ti, ts=ti.get_dagrun().logical_date.isoformat()) + context["try_number"] = try_number + return render_template_to_string(self.filename_jinja_template, context) return self.filename_template.format( dag_id=ti.dag_id, task_id=ti.task_id, - execution_date=ti.execution_date.isoformat(), + execution_date=ti.get_dagrun().logical_date.isoformat(), try_number=try_number, ) diff --git a/tests/conftest.py b/tests/conftest.py index d3578997ad246..e9a6b61efab4a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -231,6 +231,7 @@ def breeze_test_helper(request): def pytest_configure(config): + config.addinivalue_line("filterwarnings", "error::airflow.utils.context.AirflowContextDeprecationWarning") config.addinivalue_line("markers", "integration(name): mark test to run with named integration") config.addinivalue_line("markers", "backend(name): mark test to run with named backend") config.addinivalue_line("markers", "system(name): mark test to run with named system") diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 6dc2d4c377af1..db6e9e384dfc6 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -1516,6 +1516,27 @@ def test_pendulum_template_dates(self, create_task_instance): assert isinstance(template_context["data_interval_start"], pendulum.DateTime) assert isinstance(template_context["data_interval_end"], pendulum.DateTime) + def test_template_render(self, create_task_instance): + ti = create_task_instance( + dag_id="test_template_render", + task_id="test_template_render_task", + schedule_interval="0 12 * * *", + ) + template_context = ti.get_template_context() + result = ti.task.render_template("Task: {{ dag.dag_id }} -> {{ task.task_id }}", template_context) + assert result == "Task: test_template_render -> test_template_render_task" + + def test_template_render_deprecated(self, create_task_instance): + ti = create_task_instance( + dag_id="test_template_render", + task_id="test_template_render_task", + schedule_interval="0 12 * * *", + ) + template_context = ti.get_template_context() + with pytest.deprecated_call(): + result = ti.task.render_template("Execution date: {{ execution_date }}", template_context) + assert result.startswith("Execution date: ") + @pytest.mark.parametrize( "content, expected_output", [