From 22d09d30dfb4e9775e969ac78a882607e1c088dd Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Fri, 18 Aug 2023 21:17:07 +0200 Subject: [PATCH] Fix rendering the mapped parameters when using `expand_kwargs` method (#32272) * Fix rendering the mapped parameters in the mapped operator Signed-off-by: Hussein Awala * add template_in_template arg to expand method to tell Airflow whether to resolve the xcom data or not * fix dag serialization tests * Revert "fix dag serialization tests" This reverts commit 191351cda7b51bc6d49e7fcee5ab8ccd6cd219f5. * Revert "add template_in_template arg to expand method to tell Airflow whether to resolve the xcom data or not" This reverts commit 14bd392c7e5c6f25ce25d61dfd440e9b27c1bc2e. * Fix ListOfDictsExpandInput resolve method * remove _iter_parse_time_resolved_kwargs method * remove unnecessary step --------- Signed-off-by: Hussein Awala (cherry picked from commit d1e6a5c48d03322dda090113134f745d1f9c34d4) --- airflow/models/expandinput.py | 5 +- tests/models/test_mappedoperator.py | 89 +++++++++++++++++++++++++---- 2 files changed, 81 insertions(+), 13 deletions(-) diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py index 36fb5f41650a..a9128568d37d 100644 --- a/airflow/models/expandinput.py +++ b/airflow/models/expandinput.py @@ -265,7 +265,10 @@ def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any] f"expand_kwargs() input dict keys must all be str, " f"but {key!r} is of type {_describe_type(key)}" ) - return mapping, {id(v) for v in mapping.values()} + # filter out parse time resolved values from the resolved_oids + resolved_oids = {id(v) for k, v in mapping.items() if not _is_parse_time_mappable(v)} + + return mapping, resolved_oids EXPAND_INPUT_EMPTY = DictOfListsExpandInput({}) # Sentinel value. diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index a7f6d0660c76..6d4a2fbca5cf 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -20,6 +20,7 @@ import logging from collections import defaultdict from datetime import timedelta +from unittest import mock from unittest.mock import patch import pendulum @@ -399,17 +400,31 @@ def test_mapped_expand_against_params(dag_maker, dag_params, task_params, expect def test_mapped_render_template_fields_validating_operator(dag_maker, session): - class MyOperator(MockOperator): - def __init__(self, value, arg1, **kwargs): - assert isinstance(value, str), "value should have been resolved before unmapping" - assert isinstance(arg1, str), "value should have been resolved before unmapping" - super().__init__(arg1=arg1, **kwargs) - self.value = value + class MyOperator(BaseOperator): + template_fields = ("partial_template", "map_template", "file_template") + template_ext = (".ext",) + + def __init__( + self, partial_template, partial_static, map_template, map_static, file_template, **kwargs + ): + for value in [partial_template, partial_static, map_template, map_static, file_template]: + assert isinstance(value, str), "value should have been resolved before unmapping" + super().__init__(**kwargs) + self.partial_template = partial_template + self.partial_static = partial_static + self.map_template = map_template + self.map_static = map_static + self.file_template = file_template + + def execute(self, context): + pass with dag_maker(session=session): task1 = BaseOperator(task_id="op1") output1 = task1.output - mapped = MyOperator.partial(task_id="a", arg2="{{ ti.task_id }}").expand(value=output1, arg1=output1) + mapped = MyOperator.partial( + task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" + ).expand(map_template=output1, map_static=output1, file_template=["/path/to/file.ext"]) dr = dag_maker.create_dagrun() ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) @@ -432,12 +447,62 @@ def __init__(self, value, arg1, **kwargs): mapped_ti.map_index = 0 assert isinstance(mapped_ti.task, MappedOperator) - mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) + with patch("builtins.open", mock.mock_open(read_data=b"loaded data")), patch( + "os.path.isfile", return_value=True + ), patch("os.path.getmtime", return_value=0): + mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) + assert isinstance(mapped_ti.task, MyOperator) + + assert mapped_ti.task.partial_template == "a", "Should be templated!" + assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" + assert mapped_ti.task.map_template == "{{ ds }}", "Should not be templated!" + assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" + assert mapped_ti.task.file_template == "loaded data", "Should be templated!" + + +def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session): + class MyOperator(BaseOperator): + template_fields = ("partial_template", "map_template", "file_template") + template_ext = (".ext",) + + def __init__( + self, partial_template, partial_static, map_template, map_static, file_template, **kwargs + ): + for value in [partial_template, partial_static, map_template, map_static, file_template]: + assert isinstance(value, str), "value should have been resolved before unmapping" + super().__init__(**kwargs) + self.partial_template = partial_template + self.partial_static = partial_static + self.map_template = map_template + self.map_static = map_static + self.file_template = file_template + + def execute(self, context): + pass + + with dag_maker(session=session): + mapped = MyOperator.partial( + task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" + ).expand_kwargs( + [{"map_template": "{{ ds }}", "map_static": "{{ ds }}", "file_template": "/path/to/file.ext"}] + ) + + dr = dag_maker.create_dagrun() + + mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session, map_index=0) + + assert isinstance(mapped_ti.task, MappedOperator) + with patch("builtins.open", mock.mock_open(read_data=b"loaded data")), patch( + "os.path.isfile", return_value=True + ), patch("os.path.getmtime", return_value=0): + mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) assert isinstance(mapped_ti.task, MyOperator) - assert mapped_ti.task.value == "{{ ds }}", "Should not be templated!" - assert mapped_ti.task.arg1 == "{{ ds }}", "Should not be templated!" - assert mapped_ti.task.arg2 == "a" + assert mapped_ti.task.partial_template == "a", "Should be templated!" + assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" + assert mapped_ti.task.map_template == "2016-01-01", "Should be templated!" + assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" + assert mapped_ti.task.file_template == "loaded data", "Should be templated!" def test_mapped_render_nested_template_fields(dag_maker, session): @@ -534,7 +599,7 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis @pytest.mark.parametrize( "map_index, expected", [ - pytest.param(0, "{{ ds }}", id="0"), + pytest.param(0, "2016-01-01", id="0"), pytest.param(1, 2, id="1"), ], )