From 6f39c99d804989db2b5e1876621410ddabbf45d0 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Fri, 30 Jun 2023 01:15:42 +0200 Subject: [PATCH 1/8] Fix rendering the mapped parameters in the mapped operator Signed-off-by: Hussein Awala --- airflow/models/mappedoperator.py | 4 ++-- tests/models/test_mappedoperator.py | 33 +++++++++++++++++++---------- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index dd8b49fb98e0..acf32ecc61f9 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -716,7 +716,7 @@ def render_template_fields( # in the weeds here. We don't close this session for the same reason. session = settings.Session() - mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session) + mapped_kwargs, _ = self._expand_mapped_kwargs(context, session) unmapped_task = self.unmap(mapped_kwargs) context_update_for_unmapped(context, unmapped_task) @@ -729,6 +729,6 @@ def render_template_fields( template_fields=self.template_fields, context=context, jinja_env=jinja_env, - seen_oids=seen_oids, + seen_oids=set(), session=session, ) diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index d6aef854284a..ed39ec2c20b8 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -395,17 +395,27 @@ def test_mapped_expand_against_params(dag_maker, dag_params, task_params, expect def test_mapped_render_template_fields_validating_operator(dag_maker, session): - class MyOperator(MockOperator): - def __init__(self, value, arg1, **kwargs): - assert isinstance(value, str), "value should have been resolved before unmapping" - assert isinstance(arg1, str), "value should have been resolved before unmapping" - super().__init__(arg1=arg1, **kwargs) - self.value = value + class MyOperator(BaseOperator): + template_fields = ("partial_template", "map_template") + + def __init__(self, 
partial_template, partial_static, map_template, map_static, **kwargs): + for value in [partial_template, partial_static, map_template, map_static]: + assert isinstance(value, str), "value should have been resolved before unmapping" + super().__init__(**kwargs) + self.partial_template = partial_template + self.partial_static = partial_static + self.map_template = map_template + self.map_static = map_static + + def execute(self, context): + pass with dag_maker(session=session): task1 = BaseOperator(task_id="op1") output1 = task1.output - mapped = MyOperator.partial(task_id="a", arg2="{{ ti.task_id }}").expand(value=output1, arg1=output1) + mapped = MyOperator.partial( + task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" + ).expand(map_template=output1, map_static=output1) dr = dag_maker.create_dagrun() ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) @@ -431,9 +441,10 @@ def __init__(self, value, arg1, **kwargs): mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) assert isinstance(mapped_ti.task, MyOperator) - assert mapped_ti.task.value == "{{ ds }}", "Should not be templated!" - assert mapped_ti.task.arg1 == "{{ ds }}", "Should not be templated!" - assert mapped_ti.task.arg2 == "a" + assert mapped_ti.task.partial_template == "a", "Should be templated!" + assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" + assert mapped_ti.task.map_template == "2016-01-01", "Should be templated!" + assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" 
def test_mapped_render_nested_template_fields(dag_maker, session): @@ -530,7 +541,7 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis @pytest.mark.parametrize( "map_index, expected", [ - pytest.param(0, "{{ ds }}", id="0"), + pytest.param(0, "2016-01-01", id="0"), pytest.param(1, 2, id="1"), ], ) From 14bd392c7e5c6f25ce25d61dfd440e9b27c1bc2e Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Tue, 8 Aug 2023 23:16:31 +0200 Subject: [PATCH 2/8] add template_in_template arg to expand method to tell Airflow whether to resolve the xcom data or not --- airflow/models/mappedoperator.py | 16 +++++++++++----- ...pre_commit_base_operator_partial_arguments.py | 1 + tests/models/test_mappedoperator.py | 11 +++++++---- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 16b125dd9f0e..8688189a873a 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -162,14 +162,16 @@ def __del__(self): task_id = f"at {hex(id(self))}" warnings.warn(f"Task {task_id} was never mapped!") - def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator: + def expand(self, template_in_template=False, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator: if not mapped_kwargs: raise TypeError("no arguments to expand against") validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs) prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified") # Since the input is already checked at parse time, we can set strict # to False to skip the checks on execution. 
- return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False) + return self._expand( + DictOfListsExpandInput(mapped_kwargs), strict=False, template_in_template=template_in_template + ) def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator: from airflow.models.xcom_arg import XComArg @@ -182,7 +184,9 @@ def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") return self._expand(ListOfDictsExpandInput(kwargs), strict=strict) - def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: + def _expand( + self, expand_input: ExpandInput, *, strict: bool, template_in_template: bool = False + ) -> MappedOperator: from airflow.operators.empty import EmptyOperator self._expand_called = True @@ -225,6 +229,7 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: # For classic operators, this points to expand_input because kwargs # to BaseOperator.expand() contribute to operator arguments. expand_input_attr="expand_input", + template_in_template=template_in_template, ) return op @@ -261,6 +266,7 @@ class MappedOperator(AbstractOperator): template_ext: Sequence[str] template_fields: Collection[str] template_fields_renderers: dict[str, str] + template_in_template: bool = False ui_color: str ui_fgcolor: str _is_empty: bool @@ -721,7 +727,7 @@ def render_template_fields( # in the weeds here. We don't close this session for the same reason. 
session = settings.Session() - mapped_kwargs, _ = self._expand_mapped_kwargs(context, session) + mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session) unmapped_task = self.unmap(mapped_kwargs) context_update_for_unmapped(context, unmapped_task) @@ -734,6 +740,6 @@ def render_template_fields( template_fields=self.template_fields, context=context, jinja_env=jinja_env, - seen_oids=set(), + seen_oids=seen_oids if not self.template_in_template else set(), session=session, ) diff --git a/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py b/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py index 84333006cf0f..cc163630424b 100755 --- a/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py +++ b/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py @@ -50,6 +50,7 @@ # Only on MappedOperator. "expand_input", "partial_kwargs", + "template_in_template", } diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 581eba085312..b6f00abe0a78 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -398,7 +398,8 @@ def test_mapped_expand_against_params(dag_maker, dag_params, task_params, expect assert t.expand_input.value == {"params": [{"c": "x"}, {"d": 1}]} -def test_mapped_render_template_fields_validating_operator(dag_maker, session): +@pytest.mark.parametrize("template_in_template", [True, False]) +def test_mapped_render_template_fields_validating_operator(dag_maker, session, template_in_template): class MyOperator(BaseOperator): template_fields = ("partial_template", "map_template") @@ -419,7 +420,7 @@ def execute(self, context): output1 = task1.output mapped = MyOperator.partial( task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" - ).expand(map_template=output1, map_static=output1) + ).expand(template_in_template=template_in_template, map_template=output1, map_static=output1) dr = 
dag_maker.create_dagrun() ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) @@ -447,7 +448,9 @@ def execute(self, context): assert mapped_ti.task.partial_template == "a", "Should be templated!" assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" - assert mapped_ti.task.map_template == "2016-01-01", "Should be templated!" + assert ( + mapped_ti.task.map_template == "2016-01-01" if template_in_template else "{{ ds }}" + ), "Should be templated only when template_in_template is set to True!" assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" @@ -545,7 +548,7 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis @pytest.mark.parametrize( "map_index, expected", [ - pytest.param(0, "2016-01-01", id="0"), + pytest.param(0, "{{ ds }}", id="0"), pytest.param(1, 2, id="1"), ], ) From 191351cda7b51bc6d49e7fcee5ab8ccd6cd219f5 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Wed, 9 Aug 2023 00:11:32 +0200 Subject: [PATCH 3/8] fix dag serialization tests --- tests/serialization/test_dag_serialization.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 301e54a0ec99..e9fe7059cb55 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -2093,6 +2093,7 @@ def test_operator_expand_serde(): "ui_fgcolor": "#000", "_disallow_kwargs_override": False, "_expand_input_attr": "expand_input", + "template_in_template": False, } op = SerializedBaseOperator.deserialize_operator(serialized) @@ -2145,6 +2146,7 @@ def test_operator_expand_xcomarg_serde(): "ui_fgcolor": "#000", "_disallow_kwargs_override": False, "_expand_input_attr": "expand_input", + "template_in_template": False, } op = SerializedBaseOperator.deserialize_operator(serialized) @@ -2200,6 +2202,7 @@ def test_operator_expand_kwargs_literal_serde(strict): 
"ui_fgcolor": "#000", "_disallow_kwargs_override": strict, "_expand_input_attr": "expand_input", + "template_in_template": False, } op = SerializedBaseOperator.deserialize_operator(serialized) @@ -2246,6 +2249,7 @@ def test_operator_expand_kwargs_xcomarg_serde(strict): "ui_fgcolor": "#000", "_disallow_kwargs_override": strict, "_expand_input_attr": "expand_input", + "template_in_template": False, } op = SerializedBaseOperator.deserialize_operator(serialized) @@ -2359,6 +2363,7 @@ def x(arg1, arg2, arg3): "template_fields_renderers": {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"}, "_disallow_kwargs_override": False, "_expand_input_attr": "op_kwargs_expand_input", + "template_in_template": False, } deserialized = SerializedBaseOperator.deserialize_operator(serialized) @@ -2451,6 +2456,7 @@ def x(arg1, arg2, arg3): "template_fields_renderers": {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"}, "_disallow_kwargs_override": strict, "_expand_input_attr": "op_kwargs_expand_input", + "template_in_template": False, } deserialized = SerializedBaseOperator.deserialize_operator(serialized) @@ -2572,6 +2578,7 @@ def operator_extra_links(self): "_task_module": "tests.serialization.test_dag_serialization", "_is_empty": False, "_is_mapped": True, + "template_in_template": False, } deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag) assert deserialized_dag.task_dict["task"].operator_extra_links == [AirflowLink2()] From 6fc8ca68cdaadf3b84a7d4103367b5f9c4c56c8f Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Wed, 9 Aug 2023 23:32:11 +0200 Subject: [PATCH 4/8] Revert "fix dag serialization tests" This reverts commit 191351cda7b51bc6d49e7fcee5ab8ccd6cd219f5. 
--- tests/serialization/test_dag_serialization.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index e9fe7059cb55..301e54a0ec99 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -2093,7 +2093,6 @@ def test_operator_expand_serde(): "ui_fgcolor": "#000", "_disallow_kwargs_override": False, "_expand_input_attr": "expand_input", - "template_in_template": False, } op = SerializedBaseOperator.deserialize_operator(serialized) @@ -2146,7 +2145,6 @@ def test_operator_expand_xcomarg_serde(): "ui_fgcolor": "#000", "_disallow_kwargs_override": False, "_expand_input_attr": "expand_input", - "template_in_template": False, } op = SerializedBaseOperator.deserialize_operator(serialized) @@ -2202,7 +2200,6 @@ def test_operator_expand_kwargs_literal_serde(strict): "ui_fgcolor": "#000", "_disallow_kwargs_override": strict, "_expand_input_attr": "expand_input", - "template_in_template": False, } op = SerializedBaseOperator.deserialize_operator(serialized) @@ -2249,7 +2246,6 @@ def test_operator_expand_kwargs_xcomarg_serde(strict): "ui_fgcolor": "#000", "_disallow_kwargs_override": strict, "_expand_input_attr": "expand_input", - "template_in_template": False, } op = SerializedBaseOperator.deserialize_operator(serialized) @@ -2363,7 +2359,6 @@ def x(arg1, arg2, arg3): "template_fields_renderers": {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"}, "_disallow_kwargs_override": False, "_expand_input_attr": "op_kwargs_expand_input", - "template_in_template": False, } deserialized = SerializedBaseOperator.deserialize_operator(serialized) @@ -2456,7 +2451,6 @@ def x(arg1, arg2, arg3): "template_fields_renderers": {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"}, "_disallow_kwargs_override": strict, "_expand_input_attr": "op_kwargs_expand_input", - "template_in_template": False, } deserialized = 
SerializedBaseOperator.deserialize_operator(serialized) @@ -2578,7 +2572,6 @@ def operator_extra_links(self): "_task_module": "tests.serialization.test_dag_serialization", "_is_empty": False, "_is_mapped": True, - "template_in_template": False, } deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag) assert deserialized_dag.task_dict["task"].operator_extra_links == [AirflowLink2()] From 7e417b345cf950a80d5dd93517898c424b6e9218 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Wed, 9 Aug 2023 23:32:20 +0200 Subject: [PATCH 5/8] Revert "add template_in_template arg to expand method to tell Airflow whether to resolve the xcom data or not" This reverts commit 14bd392c7e5c6f25ce25d61dfd440e9b27c1bc2e. --- airflow/models/mappedoperator.py | 16 +++++----------- ...pre_commit_base_operator_partial_arguments.py | 1 - tests/models/test_mappedoperator.py | 11 ++++------- 3 files changed, 9 insertions(+), 19 deletions(-) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 8688189a873a..16b125dd9f0e 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -162,16 +162,14 @@ def __del__(self): task_id = f"at {hex(id(self))}" warnings.warn(f"Task {task_id} was never mapped!") - def expand(self, template_in_template=False, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator: + def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator: if not mapped_kwargs: raise TypeError("no arguments to expand against") validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs) prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified") # Since the input is already checked at parse time, we can set strict # to False to skip the checks on execution. 
- return self._expand( - DictOfListsExpandInput(mapped_kwargs), strict=False, template_in_template=template_in_template - ) + return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False) def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator: from airflow.models.xcom_arg import XComArg @@ -184,9 +182,7 @@ def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") return self._expand(ListOfDictsExpandInput(kwargs), strict=strict) - def _expand( - self, expand_input: ExpandInput, *, strict: bool, template_in_template: bool = False - ) -> MappedOperator: + def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: from airflow.operators.empty import EmptyOperator self._expand_called = True @@ -229,7 +225,6 @@ def _expand( # For classic operators, this points to expand_input because kwargs # to BaseOperator.expand() contribute to operator arguments. expand_input_attr="expand_input", - template_in_template=template_in_template, ) return op @@ -266,7 +261,6 @@ class MappedOperator(AbstractOperator): template_ext: Sequence[str] template_fields: Collection[str] template_fields_renderers: dict[str, str] - template_in_template: bool = False ui_color: str ui_fgcolor: str _is_empty: bool @@ -727,7 +721,7 @@ def render_template_fields( # in the weeds here. We don't close this session for the same reason. 
session = settings.Session() - mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session) + mapped_kwargs, _ = self._expand_mapped_kwargs(context, session) unmapped_task = self.unmap(mapped_kwargs) context_update_for_unmapped(context, unmapped_task) @@ -740,6 +734,6 @@ def render_template_fields( template_fields=self.template_fields, context=context, jinja_env=jinja_env, - seen_oids=seen_oids if not self.template_in_template else set(), + seen_oids=set(), session=session, ) diff --git a/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py b/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py index cc163630424b..84333006cf0f 100755 --- a/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py +++ b/scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py @@ -50,7 +50,6 @@ # Only on MappedOperator. "expand_input", "partial_kwargs", - "template_in_template", } diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index b6f00abe0a78..581eba085312 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -398,8 +398,7 @@ def test_mapped_expand_against_params(dag_maker, dag_params, task_params, expect assert t.expand_input.value == {"params": [{"c": "x"}, {"d": 1}]} -@pytest.mark.parametrize("template_in_template", [True, False]) -def test_mapped_render_template_fields_validating_operator(dag_maker, session, template_in_template): +def test_mapped_render_template_fields_validating_operator(dag_maker, session): class MyOperator(BaseOperator): template_fields = ("partial_template", "map_template") @@ -420,7 +419,7 @@ def execute(self, context): output1 = task1.output mapped = MyOperator.partial( task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" - ).expand(template_in_template=template_in_template, map_template=output1, map_static=output1) + ).expand(map_template=output1, map_static=output1) dr = 
dag_maker.create_dagrun() ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) @@ -448,9 +447,7 @@ def execute(self, context): assert mapped_ti.task.partial_template == "a", "Should be templated!" assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" - assert ( - mapped_ti.task.map_template == "2016-01-01" if template_in_template else "{{ ds }}" - ), "Should be templated only when template_in_template is set to True!" + assert mapped_ti.task.map_template == "2016-01-01", "Should be templated!" assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" @@ -548,7 +545,7 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis @pytest.mark.parametrize( "map_index, expected", [ - pytest.param(0, "{{ ds }}", id="0"), + pytest.param(0, "2016-01-01", id="0"), pytest.param(1, 2, id="1"), ], ) From 79b8441571c887f54a1766b250796f5dd6220f26 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Thu, 10 Aug 2023 01:02:10 +0200 Subject: [PATCH 6/8] Fix ListOfDictsExpandInput resolve method --- airflow/models/expandinput.py | 9 +++- airflow/models/mappedoperator.py | 4 +- tests/models/test_mappedoperator.py | 64 ++++++++++++++++++++++++++--- 3 files changed, 69 insertions(+), 8 deletions(-) diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py index 36fb5f41650a..4f88e235a6b2 100644 --- a/airflow/models/expandinput.py +++ b/airflow/models/expandinput.py @@ -217,6 +217,10 @@ class ListOfDictsExpandInput(NamedTuple): value: OperatorExpandKwargsArgument + def _iter_parse_time_resolved_kwargs(self, mapping) -> Iterable[tuple[str, Sized]]: + """Generate kwargs with values available on parse-time.""" + return ((k, v) for k, v in mapping.items() if _is_parse_time_mappable(v)) + def get_parse_time_mapped_ti_count(self) -> int: if isinstance(self.value, collections.abc.Sized): return len(self.value) @@ -265,7 +269,10 @@ def resolve(self, context: Context, session: Session) -> 
tuple[Mapping[str, Any] f"expand_kwargs() input dict keys must all be str, " f"but {key!r} is of type {_describe_type(key)}" ) - return mapping, {id(v) for v in mapping.values()} + literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs(mapping)} + resolved_oids = {id(v) for k, v in mapping.items() if k not in literal_keys} + + return mapping, resolved_oids EXPAND_INPUT_EMPTY = DictOfListsExpandInput({}) # Sentinel value. diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 16b125dd9f0e..82dcc82aa060 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -721,7 +721,7 @@ def render_template_fields( # in the weeds here. We don't close this session for the same reason. session = settings.Session() - mapped_kwargs, _ = self._expand_mapped_kwargs(context, session) + mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session) unmapped_task = self.unmap(mapped_kwargs) context_update_for_unmapped(context, unmapped_task) @@ -734,6 +734,6 @@ def render_template_fields( template_fields=self.template_fields, context=context, jinja_env=jinja_env, - seen_oids=set(), + seen_oids=seen_oids, session=session, ) diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 581eba085312..6d4a2fbca5cf 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -20,6 +20,7 @@ import logging from collections import defaultdict from datetime import timedelta +from unittest import mock from unittest.mock import patch import pendulum @@ -400,16 +401,20 @@ def test_mapped_expand_against_params(dag_maker, dag_params, task_params, expect def test_mapped_render_template_fields_validating_operator(dag_maker, session): class MyOperator(BaseOperator): - template_fields = ("partial_template", "map_template") + template_fields = ("partial_template", "map_template", "file_template") + template_ext = (".ext",) - def __init__(self, 
partial_template, partial_static, map_template, map_static, **kwargs): - for value in [partial_template, partial_static, map_template, map_static]: + def __init__( + self, partial_template, partial_static, map_template, map_static, file_template, **kwargs + ): + for value in [partial_template, partial_static, map_template, map_static, file_template]: assert isinstance(value, str), "value should have been resolved before unmapping" super().__init__(**kwargs) self.partial_template = partial_template self.partial_static = partial_static self.map_template = map_template self.map_static = map_static + self.file_template = file_template def execute(self, context): pass @@ -419,7 +424,7 @@ def execute(self, context): output1 = task1.output mapped = MyOperator.partial( task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" - ).expand(map_template=output1, map_static=output1) + ).expand(map_template=output1, map_static=output1, file_template=["/path/to/file.ext"]) dr = dag_maker.create_dagrun() ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) @@ -442,13 +447,62 @@ def execute(self, context): mapped_ti.map_index = 0 assert isinstance(mapped_ti.task, MappedOperator) - mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) + with patch("builtins.open", mock.mock_open(read_data=b"loaded data")), patch( + "os.path.isfile", return_value=True + ), patch("os.path.getmtime", return_value=0): + mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) + assert isinstance(mapped_ti.task, MyOperator) + + assert mapped_ti.task.partial_template == "a", "Should be templated!" + assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" + assert mapped_ti.task.map_template == "{{ ds }}", "Should not be templated!" + assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" 
+ assert mapped_ti.task.file_template == "loaded data", "Should be templated!" + + +def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session): + class MyOperator(BaseOperator): + template_fields = ("partial_template", "map_template", "file_template") + template_ext = (".ext",) + + def __init__( + self, partial_template, partial_static, map_template, map_static, file_template, **kwargs + ): + for value in [partial_template, partial_static, map_template, map_static, file_template]: + assert isinstance(value, str), "value should have been resolved before unmapping" + super().__init__(**kwargs) + self.partial_template = partial_template + self.partial_static = partial_static + self.map_template = map_template + self.map_static = map_static + self.file_template = file_template + + def execute(self, context): + pass + + with dag_maker(session=session): + mapped = MyOperator.partial( + task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" + ).expand_kwargs( + [{"map_template": "{{ ds }}", "map_static": "{{ ds }}", "file_template": "/path/to/file.ext"}] + ) + + dr = dag_maker.create_dagrun() + + mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session, map_index=0) + + assert isinstance(mapped_ti.task, MappedOperator) + with patch("builtins.open", mock.mock_open(read_data=b"loaded data")), patch( + "os.path.isfile", return_value=True + ), patch("os.path.getmtime", return_value=0): + mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) assert isinstance(mapped_ti.task, MyOperator) assert mapped_ti.task.partial_template == "a", "Should be templated!" assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" assert mapped_ti.task.map_template == "2016-01-01", "Should be templated!" assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" 
+ assert mapped_ti.task.file_template == "loaded data", "Should be templated!" def test_mapped_render_nested_template_fields(dag_maker, session): From ccc117c4bf3d76e611555eb56266b9b6c69c5a6d Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Thu, 10 Aug 2023 20:14:14 +0200 Subject: [PATCH 7/8] remove _iter_parse_time_resolved_kwargs method --- airflow/models/expandinput.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py index 4f88e235a6b2..7c8742914200 100644 --- a/airflow/models/expandinput.py +++ b/airflow/models/expandinput.py @@ -217,10 +217,6 @@ class ListOfDictsExpandInput(NamedTuple): value: OperatorExpandKwargsArgument - def _iter_parse_time_resolved_kwargs(self, mapping) -> Iterable[tuple[str, Sized]]: - """Generate kwargs with values available on parse-time.""" - return ((k, v) for k, v in mapping.items() if _is_parse_time_mappable(v)) - def get_parse_time_mapped_ti_count(self) -> int: if isinstance(self.value, collections.abc.Sized): return len(self.value) @@ -269,7 +265,8 @@ def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any] f"expand_kwargs() input dict keys must all be str, " f"but {key!r} is of type {_describe_type(key)}" ) - literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs(mapping)} + parse_time_resolved_kwargs = ((k, v) for k, v in mapping.items() if _is_parse_time_mappable(v)) + literal_keys = {k for k, _ in parse_time_resolved_kwargs} resolved_oids = {id(v) for k, v in mapping.items() if k not in literal_keys} return mapping, resolved_oids From e895517406be7758f0dece1a51f515f4dd8f2ef7 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Thu, 10 Aug 2023 20:17:50 +0200 Subject: [PATCH 8/8] remove unnecessary step --- airflow/models/expandinput.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py index 7c8742914200..a9128568d37d 100644 
--- a/airflow/models/expandinput.py +++ b/airflow/models/expandinput.py @@ -265,9 +265,8 @@ def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any] f"expand_kwargs() input dict keys must all be str, " f"but {key!r} is of type {_describe_type(key)}" ) - parse_time_resolved_kwargs = ((k, v) for k, v in mapping.items() if _is_parse_time_mappable(v)) - literal_keys = {k for k, _ in parse_time_resolved_kwargs} - resolved_oids = {id(v) for k, v in mapping.items() if k not in literal_keys} + # Exclude values that were already available at parse time from resolved_oids. + resolved_oids = {id(v) for k, v in mapping.items() if not _is_parse_time_mappable(v)} return mapping, resolved_oids